Skip to content

Commit

Permalink
fix geomean
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq authored Feb 13, 2025
1 parent 6226307 commit 3a7db8b
Showing 1 changed file with 35 additions and 6 deletions.
41 changes: 35 additions & 6 deletions experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import argparse
import matplotlib.pyplot as plt
import scipy.stats.mstats as ssm
import math
import random
import numpy as np
Expand Down Expand Up @@ -82,6 +83,36 @@
MAX_TESTED_COSTS = 999


# https://arxiv.org/pdf/1806.06403
def geomean(dataset):
dataset = np.array(dataset)
epsilon = 1e-5

dataset_nozeros = dataset[dataset > 0]

if len(dataset_nozeros) == 0:
return 0.0

geomeanNozeros = ssm.gmean(dataset_nozeros)

deltamin = 0
deltamax = geomeanNozeros - min(dataset_nozeros)
delta = (deltamin + deltamax) / 2

epsilon = epsilon * geomeanNozeros
auxExp = math.exp(np.mean(np.log(dataset_nozeros + delta))) - delta
while (auxExp - geomeanNozeros) > epsilon:
if auxExp < geomeanNozeros:
deltamin = delta
else:
deltamax = delta
delta = (deltamin + deltamax) / 2
auxExp = math.exp(np.mean(np.log(dataset_nozeros + delta))) - delta

gmeanE = math.exp(np.mean(np.log(dataset + delta))) - delta
return gmeanE


def run_command(command, description, capture_output=False, output_file=None, verbose=True, timeout=None):
print(f"=== {description} ===")
print("Running:", " ".join(command))
Expand Down Expand Up @@ -414,8 +445,7 @@ def get_avg_rel_error(ablate_type,tmp_dir, prefix, golden_values_file, binaries)
continue

try:
log_sum = sum(math.log1p(e) for e in valid_errors)
geo_mean = math.expm1(log_sum / len(valid_errors))
geo_mean = geomean(valid_errors)
errors[binary] = geo_mean
except OverflowError:
print(
Expand Down Expand Up @@ -955,8 +985,7 @@ def analyze_all_data(tmp_dir, thresholds=None):
min_runtime_ratios[threshold][prefix] = min_ratio
# print(f"Threshold: {threshold}, maximum runtime improvement for {prefix}: {(1 - min_ratio) * 100:.2f}%")

log_sum = sum(math.log1p(digits) for digits in original_digits)
geo_mean = math.expm1(log_sum / len(original_digits))
geo_mean = geomean(original_digits)
print(f"Original programs have {geo_mean:.2f} decimal digits of accuracy on average.")
print(f"Original programs have {max(original_digits):.2f} decimal digits of accuracy at most.")

Expand All @@ -966,8 +995,8 @@ def analyze_all_data(tmp_dir, thresholds=None):
ratios = min_runtime_ratios[threshold].values()
# print(f"\nThreshold: {threshold}, Number of valid runtime ratios: {len(ratios)}")
if ratios:
log_sum = sum(math.log1p(min(1, ratio)) for ratio in ratios)
geo_mean_ratio = math.expm1(log_sum / len(ratios))
log_sum = sum(math.log(min(1, ratio)) for ratio in ratios)
geo_mean_ratio = math.exp(log_sum / len(ratios))
percentage_improvement = (1 - geo_mean_ratio) * 100
overall_runtime_improvements[threshold] = percentage_improvement
else:
Expand Down

0 comments on commit 3a7db8b

Please sign in to comment.