Commit 930cfb3: save
sbrantq committed Oct 23, 2024 (1 parent: 2242210)
Showing 1 changed file with 136 additions and 10 deletions: experiments/run.py
@@ -14,6 +14,7 @@
import pickle

from tqdm import tqdm, trange
+from matplotlib import rcParams

ENZYME_PATH = "/home/brant/sync/Enzyme/build/Enzyme/ClangEnzyme-15.so"
LLVM_PATH = "/home/brant/llvms/llvm15/build/bin"
@@ -358,17 +359,34 @@ def get_avg_rel_error(tmp_dir, prefix, golden_values_file, binaries):
return errors


-def plot_results(plots_dir, prefix, budgets, runtimes, errors, example_adjusted_runtime=None, example_rel_err=None):
-    plot_filename = os.path.join(plots_dir, f"runtime_error_plot_{prefix[:-1]}.png")
-    print(f"=== Plotting results to {plot_filename} ===")
+def plot_results(
+    plots_dir,
+    prefix,
+    budgets,
+    runtimes,
+    errors,
+    example_adjusted_runtime=None,
+    example_rel_err=None,
+    output_format="png",
+):
+    print(f"=== Plotting results to {output_format.upper()} file ===")
+
+    rcParams["font.size"] = 16
+    rcParams["axes.titlesize"] = 18
+    rcParams["axes.labelsize"] = 16
+    rcParams["xtick.labelsize"] = 14
+    rcParams["ytick.labelsize"] = 14
+    rcParams["legend.fontsize"] = 14
+
+    plot_filename = os.path.join(plots_dir, f"runtime_error_plot_{prefix[:-1]}.{output_format}")

fig, (ax1, ax3) = plt.subplots(1, 2, figsize=(20, 8))

# First Plot: Computation Cost Budget vs Runtime and Relative Error
color_runtime = "tab:blue"
ax1.set_xlabel("Computation Cost Budget")
ax1.set_ylabel("Runtimes (seconds)", color=color_runtime)
-    (line1,) = ax1.plot(budgets, runtimes, marker="o", linestyle="-", label="Poseidoned Runtimes", color=color_runtime)
+    (line1,) = ax1.plot(budgets, runtimes, marker="o", linestyle="-", label="Optimized Runtimes", color=color_runtime)
if example_adjusted_runtime is not None:
line2 = ax1.axhline(y=example_adjusted_runtime, color=color_runtime, linestyle=":", label="Original Runtime")
ax1.tick_params(axis="y", labelcolor=color_runtime)
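
Aside: ax2 in the next hunk comes from context the diff elides; in matplotlib a second y-axis sharing the x-axis is normally created with twinx(). A minimal sketch of that dual-axis pattern, with made-up data (illustration only, not part of the commit):

    import matplotlib.pyplot as plt

    fig, ax1 = plt.subplots()
    ax1.plot([1, 2, 3], [0.30, 0.20, 0.12], color="tab:blue")  # runtimes
    ax1.set_ylabel("Runtimes (seconds)", color="tab:blue")
    ax2 = ax1.twinx()  # second y-axis on the same x-axis
    ax2.plot([1, 2, 3], [5.0, 2.0, 1.0], color="tab:green")  # relative errors (%)
    ax2.set_ylabel("Relative Errors (%)", color="tab:green")
    fig.savefig("dual_axis_sketch.png")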
@@ -377,7 +395,7 @@ def plot_results(plots_dir, prefix, budgets, runtimes, errors, example_adjusted_
color_error = "tab:green"
ax2.set_ylabel("Relative Errors (%)", color=color_error)
(line3,) = ax2.plot(
-        budgets, errors, marker="s", linestyle="-", label="Poseidoned Relative Errors", color=color_error
+        budgets, errors, marker="s", linestyle="-", label="Optimized Relative Errors", color=color_error
)
if example_rel_err is not None:
line4 = ax2.axhline(y=example_rel_err, color=color_error, linestyle=":", label="Original Relative Error")
@@ -406,12 +424,12 @@ def plot_results(plots_dir, prefix, budgets, runtimes, errors, example_adjusted_
frameon=False,
)

-    # Second Plot: Pareto Front of Poseidoned Programs
+    # Second Plot: Pareto Front of Optimized Programs
    ax3.set_xlabel("Runtimes (seconds)")
    ax3.set_ylabel("Relative Errors (%)")
-    ax3.set_title(f"Pareto Front of Poseidoned Programs ({prefix[:-1]})")
+    ax3.set_title(f"Pareto Front of Optimized Programs ({prefix[:-1]})")

-    scatter1 = ax3.scatter(runtimes, errors, label="Poseidoned Programs", color="blue")
+    scatter1 = ax3.scatter(runtimes, errors, label="Optimized Programs", color="blue")

if example_adjusted_runtime is not None and example_rel_err is not None:
scatter2 = ax3.scatter(
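
Aside: the second panel is labeled a Pareto front, but the commit scatters every (runtime, error) point rather than filtering to the non-dominated set. If one wanted the front itself, a minimal filter could look like the following sketch; pareto_front is a hypothetical helper, not a function from this file:

    def pareto_front(points):
        # Keep points with no other point at least as good on both axes
        # and strictly better on one (minimizing runtime and error).
        front = []
        for rt, err in points:
            dominated = any(
                rt2 <= rt and err2 <= err and (rt2 < rt or err2 < err)
                for rt2, err2 in points
            )
            if not dominated:
                front.append((rt, err))
        return sorted(front)

    print(pareto_front([(1.0, 5.0), (2.0, 1.0), (2.5, 1.0)]))
    # [(1.0, 5.0), (2.0, 1.0)] -- (2.5, 1.0) is dominated by (2.0, 1.0)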
@@ -536,6 +554,7 @@ def benchmark(tmp_dir, logs_dir, prefix, plots_dir):
optimized_binaries.append(output_binary)

errors_dict = get_avg_rel_error(tmp_dir, prefix, golden_values_file, optimized_binaries)
+    errors = []
for binary in optimized_binaries:
errors.append(errors_dict[binary])
print(f"Average rel error for {binary}: {errors_dict[binary]}")
@@ -563,7 +582,7 @@ def benchmark(tmp_dir, logs_dir, prefix, plots_dir):
)


-def plot_from_data(tmp_dir, plots_dir, prefix):
+def plot_from_data(tmp_dir, plots_dir, prefix, output_format="png"):
data_file = os.path.join(tmp_dir, f"{prefix}benchmark_data.pkl")
if not os.path.exists(data_file):
print(f"Data file {data_file} does not exist. Cannot plot.")
@@ -578,9 +597,111 @@ def plot_from_data(tmp_dir, plots_dir, prefix):
data["errors"],
example_adjusted_runtime=data["example_adjusted_runtime"],
example_rel_err=data["example_rel_err"],
+        output_format=output_format,
)


+def analyze_all_data(tmp_dir, thresholds=None):
+    prefixes = []
+    data_list = []
+
+    for filename in os.listdir(tmp_dir):
+        if filename.endswith("benchmark_data.pkl"):
+            data_file = os.path.join(tmp_dir, filename)
+            with open(data_file, "rb") as f:
+                data = pickle.load(f)
+            prefix = filename[: -len("benchmark_data.pkl")]
+            prefixes.append(prefix)
+            data_list.append((prefix, data))
+
+    if not data_list:
+        print("No benchmark data files found in the tmp directory.")
+        return
+
+    print(f"Analyzing data for prefixes: {', '.join(prefixes)}\n")
+
+    if thresholds is None:
+        thresholds = [0, 1e-10, 1e-9, 1e-8, 1e-6, 1e-4, 1e-2, 1e-1, 0.2, 0.3, 0.4, 0.5, 0.9, 1]
+
+    max_accuracy_improvements = {}  # per benchmark
+    runtime_ratios_per_threshold = {threshold: [] for threshold in thresholds}
+
+    for prefix, data in data_list:
+        budgets = data["budgets"]
+        runtimes = data["runtimes"]
+        errors = data["errors"]
+        example_adjusted_runtime = data["example_adjusted_runtime"]
+        example_rel_err = data["example_rel_err"]
+
+        if example_rel_err is None or example_rel_err <= 0:
+            example_digits = None
+        else:
+            example_digits = -math.log2(example_rel_err / 100)
+
+        digits_list = []
+        for err in errors:
+            if err is None or err <= 0:
+                digits_list.append(None)
+            else:
+                digits = -math.log2(err / 100)
+                digits_list.append(digits)
+
+        accuracy_improvements = []
+        for digits in digits_list:
+            if digits is not None and example_digits is not None:
+                improvement = digits - example_digits
+                accuracy_improvements.append(improvement)
+            else:
+                accuracy_improvements.append(None)
+
+        max_improvement = None
+        for improvement in accuracy_improvements:
+            if improvement is not None:
+                if improvement <= 0:
+                    continue
+                if max_improvement is None or improvement > max_improvement:
+                    max_improvement = improvement
+
+        if max_improvement is None:
+            max_accuracy_improvements[prefix] = 0.0
+        else:
+            max_accuracy_improvements[prefix] = max_improvement
+
+        for err, runtime in zip(errors, runtimes):
+            if err is not None and runtime is not None:
+                for threshold in thresholds:
+                    if err <= threshold * 100:
+                        runtime_ratio = runtime / example_adjusted_runtime
+                        runtime_ratios_per_threshold[threshold].append(runtime_ratio)
+
+    overall_runtime_improvements = {}
+    for threshold in thresholds:
+        ratios = runtime_ratios_per_threshold[threshold]
+        if ratios:
+            log_sum = sum(math.log(ratio) for ratio in ratios)
+            geo_mean_ratio = math.exp(log_sum / len(ratios))
+            percentage_improvement = (1 - geo_mean_ratio) * 100
+            overall_runtime_improvements[threshold] = percentage_improvement
+        else:
+            overall_runtime_improvements[threshold] = None
+
+    print("Maximum accuracy improvements (in number of bits) per benchmark:")
+    for prefix in prefixes:
+        improvement = max_accuracy_improvements.get(prefix)
+        if improvement is not None:
+            print(f"{prefix}: {improvement:.2f} bits")
+        else:
+            print(f"{prefix}: No improvement")
+
+    print("\nGeometric average percentage of runtime improvements while allowing some level of relative error:")
+    for threshold in thresholds:
+        percentage_improvement = overall_runtime_improvements[threshold]
+        if percentage_improvement is not None:
+            print(f"Allowed relative error ≤ {threshold}: {percentage_improvement:.2f}% runtime improvement")
+        else:
+            print(f"Allowed relative error ≤ {threshold}: No data")
+

def build_with_benchmark(tmp_dir, logs_dir, plots_dir, prefix):
build_all(tmp_dir, logs_dir, prefix)
benchmark(tmp_dir, logs_dir, prefix, plots_dir)
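
Aside: a quick sanity check of the arithmetic in analyze_all_data above, with made-up numbers. Errors are stored as percentages, so a 0.5% relative error corresponds to -log2(0.5/100) ≈ 7.64 bits of accuracy, and runtime ratios are averaged geometrically:

    import math

    err_percent = 0.5  # 0.5% relative error
    bits = -math.log2(err_percent / 100)  # ~7.64 bits of accuracy

    ratios = [0.8, 0.9]  # optimized runtime / original runtime
    geo_mean = math.exp(sum(math.log(r) for r in ratios) / len(ratios))  # ~0.849
    print(f"{bits:.2f} bits, {(1 - geo_mean) * 100:.2f}% runtime improvement")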
@@ -594,6 +715,8 @@ def main():
parser.add_argument("--benchmark", action="store_true", help="Run benchmark")
parser.add_argument("--all", action="store_true", help="Build and run benchmark")
parser.add_argument("--plot-only", action="store_true", help="Plot results from existing data")
parser.add_argument("--output-format", type=str, default="png", help="Output format for plots (e.g., png, pdf)")
parser.add_argument("--analytics", action="store_true", help="Run analytics on saved data")
args = parser.parse_args()

prefix = args.prefix
@@ -618,7 +741,10 @@ def main():
benchmark(tmp_dir, logs_dir, prefix, plots_dir)
sys.exit(0)
elif args.plot_only:
-        plot_from_data(tmp_dir, plots_dir, prefix)
+        plot_from_data(tmp_dir, plots_dir, prefix, output_format=args.output_format)
        sys.exit(0)
+    elif args.analytics:
+        analyze_all_data(tmp_dir)
+        sys.exit(0)
elif args.all:
build_with_benchmark(tmp_dir, logs_dir, plots_dir, prefix)
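
Aside: with the two new flags, invocations might look like the lines below, assuming the script is run from the repository root (other flags such as --prefix are defined elsewhere in the file):

    python experiments/run.py --plot-only --output-format pdf
    python experiments/run.py --analytics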