diff --git a/experiments/run.py b/experiments/run.py index efb58be..073de95 100755 --- a/experiments/run.py +++ b/experiments/run.py @@ -8,6 +8,7 @@ import re import argparse import matplotlib.pyplot as plt +import scipy.stats.mstats as ssm import math import random import numpy as np @@ -82,6 +83,36 @@ MAX_TESTED_COSTS = 999 +# https://arxiv.org/pdf/1806.06403 +def geomean(dataset): + dataset = np.array(dataset) + epsilon = 1e-5 + + dataset_nozeros = dataset[dataset > 0] + + if len(dataset_nozeros) == 0: + return 0.0 + + geomeanNozeros = ssm.gmean(dataset_nozeros) + + deltamin = 0 + deltamax = geomeanNozeros - min(dataset_nozeros) + delta = (deltamin + deltamax) / 2 + + epsilon = epsilon * geomeanNozeros + auxExp = math.exp(np.mean(np.log(dataset_nozeros + delta))) - delta + while (auxExp - geomeanNozeros) > epsilon: + if auxExp < geomeanNozeros: + deltamin = delta + else: + deltamax = delta + delta = (deltamin + deltamax) / 2 + auxExp = math.exp(np.mean(np.log(dataset_nozeros + delta))) - delta + + gmeanE = math.exp(np.mean(np.log(dataset + delta))) - delta + return gmeanE + + def run_command(command, description, capture_output=False, output_file=None, verbose=True, timeout=None): print(f"=== {description} ===") print("Running:", " ".join(command)) @@ -414,8 +445,7 @@ def get_avg_rel_error(ablate_type,tmp_dir, prefix, golden_values_file, binaries) continue try: - log_sum = sum(math.log1p(e) for e in valid_errors) - geo_mean = math.expm1(log_sum / len(valid_errors)) + geo_mean = geomean(valid_errors) errors[binary] = geo_mean except OverflowError: print( @@ -955,8 +985,7 @@ def analyze_all_data(tmp_dir, thresholds=None): min_runtime_ratios[threshold][prefix] = min_ratio # print(f"Threshold: {threshold}, maximum runtime improvement for {prefix}: {(1 - min_ratio) * 100:.2f}%") - log_sum = sum(math.log1p(digits) for digits in original_digits) - geo_mean = math.expm1(log_sum / len(original_digits)) + geo_mean = geomean(original_digits) print(f"Original programs have {geo_mean:.2f} decimal digits of accuracy on average.") print(f"Original programs have {max(original_digits):.2f} decimal digits of accuracy at most.") @@ -966,8 +995,8 @@ def analyze_all_data(tmp_dir, thresholds=None): ratios = min_runtime_ratios[threshold].values() # print(f"\nThreshold: {threshold}, Number of valid runtime ratios: {len(ratios)}") if ratios: - log_sum = sum(math.log1p(min(1, ratio)) for ratio in ratios) - geo_mean_ratio = math.expm1(log_sum / len(ratios)) + log_sum = sum(math.log(min(1, ratio)) for ratio in ratios) + geo_mean_ratio = math.exp(log_sum / len(ratios)) percentage_improvement = (1 - geo_mean_ratio) * 100 overall_runtime_improvements[threshold] = percentage_improvement else: