diff --git a/ablations/.DS_Store b/ablations/.DS_Store new file mode 100644 index 0000000..3614cf1 Binary files /dev/null and b/ablations/.DS_Store differ diff --git a/ablations/ablation.py b/ablations/ablation.py new file mode 100644 index 0000000..18eb001 --- /dev/null +++ b/ablations/ablation.py @@ -0,0 +1,935 @@ +#!/usr/bin/env python3 + +import os +import subprocess +import sys +import shutil +import re +import argparse +import matplotlib.pyplot as plt +import math +import random +import numpy as np +import json +from statistics import mean +import pickle +from tqdm import trange +from matplotlib import rcParams +from concurrent.futures import ProcessPoolExecutor, as_completed + +HOME = "/home/sbrantq" +ENZYME_PATH = os.path.join(HOME, "sync/Enzyme/build-debug/Enzyme/ClangEnzyme-16.so") +LLVM_PATH = os.path.join(HOME, "llvms/llvm16/build/bin") +CXX = os.path.join(LLVM_PATH, "clang++") + +CXXFLAGS = [ + "-O3", + "-I" + os.path.join(HOME, "include"), + "-L" + os.path.join(HOME, "lib"), + "-I/usr/include/c++/11", + "-I/usr/include/x86_64-linux-gnu/c++/11", + "-L/usr/lib/gcc/x86_64-linux-gnu/11", + "-fno-exceptions", + f"-fpass-plugin={ENZYME_PATH}", + "-Xclang", + "-load", + "-Xclang", + ENZYME_PATH, + "-lmpfr", + "-ffast-math", + "-fno-finite-math-only", + "-fuse-ld=lld", +] + +FPOPTFLAGS_BASE_TEMPLATE = [ + "-mllvm", + "--enzyme-enable-fpopt", + "-mllvm", + "--enzyme-print-herbie", + "-mllvm", + "--enzyme-print-fpopt", + "-mllvm", + "--fpopt-log-path=example.txt", + "-mllvm", + "--fpopt-target-func-regex=example", + "-mllvm", + "--fpopt-enable-solver", + "-mllvm", + "--fpopt-enable-pt", + "-mllvm", + "--fpopt-cost-dom-thres=0", + "-mllvm", + "--fpopt-acc-dom-thres=0", + "-mllvm", + "--fpopt-comp-cost-budget=0", + "-mllvm", + "--herbie-num-threads=8", + "-mllvm", + "--herbie-timeout=1000", + "-mllvm", + "--fpopt-num-samples=1024", + "-mllvm", + "--fpopt-cost-model-path=/home/sbrantq/sync/FPBench/microbm/cm.csv", + "-mllvm", + "-fpopt-cache-path=cache", +] + +SRC = "example.c" +LOGGER = "fp-logger.cpp" +NUM_RUNS = 10 +DRIVER_NUM_SAMPLES = 10000000 +LOG_NUM_SAMPLES = 10000 +MAX_TESTED_COSTS = 999 + + +def run_command(command, description, capture_output=False, output_file=None, verbose=True, timeout=None): + print(f"=== {description} ===") + print("Running:", " ".join(command)) + try: + if capture_output and output_file: + with open(output_file, "w") as f: + subprocess.check_call(command, stdout=f, stderr=subprocess.STDOUT, timeout=timeout) + elif capture_output: + result = subprocess.run(command, capture_output=True, text=True, check=True, timeout=timeout) + return result.stdout + else: + if verbose: + subprocess.check_call(command, timeout=timeout) + else: + subprocess.check_call(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=timeout) + except subprocess.TimeoutExpired: + print(f"Command '{' '.join(command)}' timed out after {timeout} seconds.") + return + except subprocess.CalledProcessError as e: + print(f"Error during: {description}") + if capture_output and output_file: + print(f"Check the output file: {output_file} for details.") + else: + print(e) + sys.exit(e.returncode) + + +def clean(tmp_dir, logs_dir, plots_dir): + print("=== Cleaning up generated files ===") + directories = [tmp_dir, logs_dir, plots_dir] + for directory in directories: + if os.path.exists(directory): + shutil.rmtree(directory) + print(f"Removed directory: {directory}") + + +def generate_example_cpp(tmp_dir, original_prefix, prefix): + script = "fpopt-original-driver-generator.py" + 
print(f"=== Running {script} ===") + src_prefixed = os.path.join(tmp_dir, f"{original_prefix}{SRC}") + dest_prefixed = os.path.join(tmp_dir, f"{prefix}example.cpp") + run_command( + ["python3", script, src_prefixed, dest_prefixed, "example", str(DRIVER_NUM_SAMPLES)], + f"Generating example.cpp from {SRC}", + ) + if not os.path.exists(dest_prefixed): + print(f"Failed to generate {dest_prefixed}.") + sys.exit(1) + print(f"Generated {dest_prefixed} successfully.") + + +def generate_example_logged_cpp(tmp_dir, original_prefix, prefix): + script = "fpopt-logged-driver-generator.py" + print(f"=== Running {script} ===") + src_prefixed = os.path.join(tmp_dir, f"{original_prefix}{SRC}") + dest_prefixed = os.path.join(tmp_dir, f"{prefix}example-logged.cpp") + run_command( + ["python3", script, src_prefixed, dest_prefixed, "example", str(LOG_NUM_SAMPLES)], + f"Generating example-logged.cpp from {SRC}", + ) + if not os.path.exists(dest_prefixed): + print(f"Failed to generate {dest_prefixed}.") + sys.exit(1) + print(f"Generated {dest_prefixed} successfully.") + + +def compile_example_exe(tmp_dir, prefix): + source = os.path.join(tmp_dir, f"{prefix}example.cpp") + output = os.path.join(tmp_dir, f"{prefix}example.exe") + cmd = [CXX, source] + CXXFLAGS + ["-o", output] + run_command(cmd, f"Compiling {output}") + + +def compile_example_logged_exe(tmp_dir, prefix): + source = os.path.join(tmp_dir, f"{prefix}example-logged.cpp") + output = os.path.join(tmp_dir, f"{prefix}example-logged.exe") + cmd = [CXX, source, LOGGER] + CXXFLAGS + ["-o", output] + run_command(cmd, f"Compiling {output}") + + +def generate_example_txt(tmp_dir, prefix): + exe = os.path.join(tmp_dir, f"{prefix}example-logged.exe") + output = os.path.join(tmp_dir, f"{prefix}example.txt") + if not os.path.exists(exe): + print(f"Executable {exe} not found. 
Cannot generate {output}.") + sys.exit(1) + with open(output, "w") as f: + print(f"=== Running {exe} to generate {output} ===") + try: + subprocess.check_call([exe], stdout=f) + except subprocess.TimeoutExpired: + print(f"Execution of {exe} timed out.") + if os.path.exists(exe): + os.remove(exe) + print(f"Removed executable {exe} due to timeout.") + return + except subprocess.CalledProcessError as e: + print(f"Error running {exe}") + sys.exit(e.returncode) + + +def compile_example_fpopt_exe(tmp_dir, prefix, fpoptflags, output="example-fpopt.exe", verbose=True): + source = os.path.join(tmp_dir, f"{prefix}example.cpp") + output_path = os.path.join(tmp_dir, f"{prefix}{output}") + cmd = [CXX, source] + CXXFLAGS + fpoptflags + ["-o", output_path] + log_path = os.path.join("logs", f"{prefix}compile_fpopt.log") + if output == "example-fpopt.exe": + run_command( + cmd, + f"Compiling {output_path} with FPOPTFLAGS", + capture_output=True, + output_file=log_path, + verbose=verbose, + ) + else: + run_command( + cmd, + f"Compiling {output_path} with FPOPTFLAGS", + verbose=verbose, + ) + + +def parse_critical_comp_costs(tmp_dir, prefix, log_path="compile_fpopt.log"): + print(f"=== Parsing critical computation costs from {log_path} ===") + full_log_path = os.path.join("logs", f"{prefix}{log_path}") + if not os.path.exists(full_log_path): + print(f"Log file {full_log_path} does not exist.") + sys.exit(1) + with open(full_log_path, "r") as f: + content = f.read() + + pattern = r"\*\*\* Critical Computation Costs \*\*\*(.*?)\*\*\* End of Critical Computation Costs \*\*\*" + match = re.search(pattern, content, re.DOTALL) + if not match: + print("Critical Computation Costs block not found in the log.") + sys.exit(1) + + costs_str = match.group(1).strip() + costs = [int(cost) for cost in costs_str.split(",") if re.fullmatch(r"-?\d+", cost.strip())] + print(f"Parsed computation costs: {costs}") + + if not costs: + print("No valid computation costs found to sample.") + sys.exit(1) + + num_to_sample = min(MAX_TESTED_COSTS, len(costs)) + + sampled_costs = random.sample(costs, num_to_sample) + + sampled_costs_sorted = sorted(sampled_costs) + + print(f"Sampled computation costs (sorted): {sampled_costs_sorted}") + + return sampled_costs_sorted + + +def measure_runtime(tmp_dir, prefix, executable, num_runs=NUM_RUNS): + print(f"=== Measuring runtime for {executable} ===") + runtimes = [] + exe_path = os.path.join(tmp_dir, f"{prefix}{executable}") + for i in trange(1, num_runs + 1): + try: + result = subprocess.run([exe_path], capture_output=True, text=True, check=True, timeout=300) + output = result.stdout + match = re.search(r"Total runtime: ([\d\.]+) seconds", output) + if match: + runtime = float(match.group(1)) + runtimes.append(runtime) + else: + print(f"Could not parse runtime from output on run {i}") + sys.exit(1) + except subprocess.TimeoutExpired: + print(f"Execution of {exe_path} timed out on run {i}") + if os.path.exists(exe_path): + os.remove(exe_path) + print(f"Removed executable {exe_path} due to timeout.") + return None + except subprocess.CalledProcessError as e: + print(f"Error running {exe_path} on run {i}") + sys.exit(e.returncode) + if runtimes: + average_runtime = mean(runtimes) + print(f"Average runtime for {prefix}{executable}: {average_runtime:.6f} seconds") + return average_runtime + else: + print(f"No successful runs for {prefix}{executable}") + return None + + +def get_values_file_path(tmp_dir, prefix, binary_name): + return os.path.join(tmp_dir, f"{prefix}{binary_name}-values.txt") + + +def 
generate_example_values(tmp_dir, prefix): + binary_name = "example.exe" + exe = os.path.join(tmp_dir, f"{prefix}{binary_name}") + output_values_file = get_values_file_path(tmp_dir, prefix, binary_name) + cmd = [exe, "--output-path", output_values_file] + run_command(cmd, f"Generating function values from {binary_name}", verbose=False, timeout=300) + + +def generate_values(tmp_dir, prefix, binary_name): + exe = os.path.join(tmp_dir, f"{prefix}{binary_name}") + values_file = get_values_file_path(tmp_dir, prefix, binary_name) + cmd = [exe, "--output-path", values_file] + run_command(cmd, f"Generating function values from {binary_name}", verbose=False, timeout=300) + + +def compile_golden_exe(tmp_dir, prefix): + source = os.path.join(tmp_dir, f"{prefix}golden.cpp") + output = os.path.join(tmp_dir, f"{prefix}golden.exe") + cmd = [CXX, source] + CXXFLAGS + ["-o", output] + run_command(cmd, f"Compiling {output}") + + +def generate_golden_values(tmp_dir, original_prefix, prefix): + script = "fpopt-golden-driver-generator.py" + src_prefixed = os.path.join(tmp_dir, f"{original_prefix}{SRC}") + dest_prefixed = os.path.join(tmp_dir, f"{prefix}golden.cpp") + cur_prec = 128 + max_prec = 4096 + PREC_step = 128 + prev_output = None + output_values_file = get_values_file_path(tmp_dir, prefix, "golden.exe") + while cur_prec <= max_prec: + run_command( + ["python3", script, src_prefixed, dest_prefixed, str(cur_prec), "example", str(DRIVER_NUM_SAMPLES)], + f"Generating golden.cpp with PREC={cur_prec}", + ) + if not os.path.exists(dest_prefixed): + print(f"Failed to generate {dest_prefixed}.") + sys.exit(1) + print(f"Generated {dest_prefixed} successfully.") + + compile_golden_exe(tmp_dir, prefix) + + exe = os.path.join(tmp_dir, f"{prefix}golden.exe") + cmd = [exe, "--output-path", output_values_file] + run_command(cmd, f"Generating golden values with PREC={cur_prec}", verbose=False) + + if not os.path.exists(output_values_file): + print(f"Failed to generate golden values at PREC={cur_prec} due to timeout.") + return + + with open(output_values_file, "r") as f: + output = f.read() + + if output == prev_output: + print(f"Golden values converged at PREC={cur_prec}") + break + else: + prev_output = output + cur_prec += PREC_step + else: + print(f"Failed to converge golden values up to PREC={max_prec}") + sys.exit(1) + + +def get_avg_rel_error(tmp_dir, prefix, golden_values_file, binaries): + with open(golden_values_file, "r") as f: + golden_values = [float(line.strip()) for line in f] + + errors = {} + for binary in binaries: + values_file = get_values_file_path(tmp_dir, prefix, binary) + if not os.path.exists(values_file): + print(f"Values file {values_file} does not exist. Skipping error calculation for {binary}.") + errors[binary] = None + continue + with open(values_file, "r") as f: + values = [float(line.strip()) for line in f] + if len(values) != len(golden_values): + print(f"Number of values in {values_file} does not match golden values") + sys.exit(1) + + valid_errors = [] + for v, g in zip(values, golden_values): + if math.isnan(v) or math.isnan(g): + continue + if g == 0: + continue + error = abs((v - g) / g) * 100 + valid_errors.append(error) + + if not valid_errors: + print(f"No valid data to compute rel error for binary {binary}. 
Setting rel error to None.") + errors[binary] = None + continue + + try: + log_sum = sum(math.log1p(e) for e in valid_errors) + geo_mean = math.expm1(log_sum / len(valid_errors)) + errors[binary] = geo_mean + except OverflowError: + print( + f"Overflow error encountered while computing geometric mean for binary {binary}. Setting rel error to None." + ) + errors[binary] = None + except ZeroDivisionError: + print(f"No valid errors to compute geometric mean for binary {binary}. Setting rel error to None.") + errors[binary] = None + + return errors + + +def build_all(tmp_dir, logs_dir, original_prefix, prefix, fpoptflags, example_txt_path): + generate_example_cpp(tmp_dir, original_prefix, prefix) + generate_example_logged_cpp(tmp_dir, original_prefix, prefix) + compile_example_exe(tmp_dir, prefix) + compile_example_logged_exe(tmp_dir, prefix) + if not os.path.exists(example_txt_path): + generate_example_txt(tmp_dir, original_prefix) + compile_example_fpopt_exe(tmp_dir, prefix, fpoptflags, output="example-fpopt.exe") + print("=== Initial build process completed successfully ===") + + +def process_cost(args): + cost, tmp_dir, prefix, fpoptflags = args + + print(f"\n=== Processing computation cost budget: {cost} ===") + fpoptflags_cost = [] + for flag in fpoptflags: + if flag.startswith("--fpopt-comp-cost-budget="): + fpoptflags_cost.append(f"--fpopt-comp-cost-budget={cost}") + else: + fpoptflags_cost.append(flag) + + output_binary = f"example-fpopt-{cost}.exe" + + compile_example_fpopt_exe(tmp_dir, prefix, fpoptflags_cost, output=output_binary, verbose=False) + generate_values(tmp_dir, prefix, output_binary) + + return cost, output_binary + + +def benchmark(tmp_dir, logs_dir, original_prefix, prefix, plots_dir, fpoptflags, num_parallel=1): + costs = parse_critical_comp_costs(tmp_dir, prefix) + + original_avg_runtime = measure_runtime(tmp_dir, prefix, "example.exe", NUM_RUNS) + original_runtime = original_avg_runtime + + if original_runtime is None: + print("Original binary timed out. 
Proceeding as if it doesn't exist.") + return + + generate_example_values(tmp_dir, prefix) + + generate_golden_values(tmp_dir, original_prefix, prefix) + + golden_values_file = get_values_file_path(tmp_dir, prefix, "golden.exe") + example_binary = "example.exe" + rel_errs_example = get_avg_rel_error(tmp_dir, prefix, golden_values_file, [example_binary]) + rel_err_example = rel_errs_example[example_binary] + print(f"Average Rel Error for {prefix}example.exe: {rel_err_example}") + + data_tuples = [] + + args_list = [(cost, tmp_dir, prefix, fpoptflags) for cost in costs] + + if num_parallel == 1: + for args in args_list: + cost, output_binary = process_cost(args) + data_tuples.append((cost, output_binary)) + else: + with ProcessPoolExecutor(max_workers=num_parallel) as executor: + future_to_cost = {executor.submit(process_cost, args): args[0] for args in args_list} + for future in as_completed(future_to_cost): + cost = future_to_cost[future] + try: + cost_result, output_binary = future.result() + data_tuples.append((cost_result, output_binary)) + except Exception as exc: + print(f"Cost {cost} generated an exception: {exc}") + + data_tuples_sorted = sorted(data_tuples, key=lambda x: x[0]) + sorted_budgets, sorted_optimized_binaries = zip(*data_tuples_sorted) if data_tuples_sorted else ([], []) + + sorted_runtimes = [] + for cost, output_binary in zip(sorted_budgets, sorted_optimized_binaries): + avg_runtime = measure_runtime(tmp_dir, prefix, output_binary, NUM_RUNS) + if avg_runtime is not None: + sorted_runtimes.append(avg_runtime) + else: + print(f"Skipping cost {cost} due to runtime measurement failure.") + sorted_runtimes.append(None) + + errors_dict = get_avg_rel_error(tmp_dir, prefix, golden_values_file, sorted_optimized_binaries) + sorted_errors = [] + for binary in sorted_optimized_binaries: + sorted_errors.append(errors_dict.get(binary)) + print(f"Average rel error for {binary}: {errors_dict.get(binary)}") + + sorted_budgets = list(sorted_budgets) + sorted_runtimes = list(sorted_runtimes) + sorted_errors = list(sorted_errors) + + data = { + "budgets": sorted_budgets, + "runtimes": sorted_runtimes, + "errors": sorted_errors, + "original_runtime": original_runtime, + "original_error": rel_err_example, + } + + table_json_path = os.path.join("cache", "table.json") + if os.path.exists(table_json_path): + with open(table_json_path, "r") as f: + table_data = json.load(f) + if "costToAccuracyMap" in table_data: + predicted_costs = [] + predicted_errors = [] + for k, v in table_data["costToAccuracyMap"].items(): + try: + cost_val = float(k) + err_val = float(v) + predicted_costs.append(cost_val) + predicted_errors.append(err_val) + except ValueError: + pass + pred_sorted_indices = np.argsort(predicted_costs) + predicted_costs = np.array(predicted_costs)[pred_sorted_indices].tolist() + predicted_errors = np.array(predicted_errors)[pred_sorted_indices].tolist() + data["predicted_data"] = {"costs": predicted_costs, "errors": predicted_errors} + + data_file = os.path.join(tmp_dir, f"{prefix}benchmark_data.pkl") + with open(data_file, "wb") as f: + pickle.dump(data, f) + print(f"Benchmark data saved to {data_file}") + + return data + + +def build_with_benchmark( + tmp_dir, logs_dir, plots_dir, original_prefix, prefix, fpoptflags, example_txt_path, num_parallel=1 +): + build_all(tmp_dir, logs_dir, original_prefix, prefix, fpoptflags, example_txt_path) + data = benchmark(tmp_dir, logs_dir, original_prefix, prefix, plots_dir, fpoptflags, num_parallel) + return data + + +def remove_cache_dir(): + 
cache_dir = "cache" + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + print("=== Removed existing cache directory ===") + + +def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_format="png", show_prediction=False): + ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation-widen-range.pkl") + if not os.path.exists(ablation_data_file): + print(f"Ablation data file {ablation_data_file} does not exist. Cannot plot.") + sys.exit(1) + with open(ablation_data_file, "rb") as f: + all_data = pickle.load(f) + + if not all_data: + print("No data to plot.") + sys.exit(1) + + rcParams["font.size"] = 20 + rcParams["axes.titlesize"] = 24 + rcParams["axes.labelsize"] = 20 + rcParams["xtick.labelsize"] = 18 + rcParams["ytick.labelsize"] = 18 + rcParams["legend.fontsize"] = 18 + + if show_prediction and any("predicted_data" in d for d in all_data.values()): + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8)) + else: + fig, ax1 = plt.subplots(1, 1, figsize=(10, 8)) + ax2 = None + + # Extract original runtime/error from the first scenario + first_key = next(iter(all_data)) + original_runtime = all_data[first_key]["original_runtime"] + original_error = all_data[first_key]["original_error"] + + # Plot original program once + ax1.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program") + + colors = ["blue", "green", "red", "purple", "orange", "brown", "pink", "gray", "olive", "cyan"] + color_iter = iter(colors) + + for X, data in sorted(all_data.items()): + budgets = data["budgets"] + runtimes = data["runtimes"] + errors = data["errors"] + + data_points = list(zip(runtimes, errors)) + filtered_data = [(r, e) for r, e in data_points if r is not None and e is not None] + if not filtered_data: + print(f"No valid data to plot for widen-range={X}.") + continue + runtimes_filtered, errors_filtered = zip(*filtered_data) + color = next(color_iter) + ax1.scatter(runtimes_filtered, errors_filtered, label=f"widen-range={X}", color=color) + points = np.array(filtered_data) + sorted_indices = np.argsort(points[:, 0]) + sorted_points = points[sorted_indices] + + # Compute pareto front + pareto_front = [sorted_points[0]] + for point in sorted_points[1:]: + if point[1] < pareto_front[-1][1]: + pareto_front.append(point) + + pareto_front = np.array(pareto_front) + ax1.step( + pareto_front[:, 0], + pareto_front[:, 1], + where="post", + linestyle="-", + color=color, + ) + + if show_prediction and "predicted_data" in data and ax2 is not None: + p_costs = data["predicted_data"]["costs"] + p_errors = data["predicted_data"]["errors"] + if p_costs and p_errors: + ax2.scatter(p_costs, p_errors, label=f"Prediction (widen-range={X})", color=color) + pred_points = np.column_stack((p_costs, p_errors)) + pred_sorted_indices = np.argsort(pred_points[:, 0]) + pred_sorted_points = pred_points[pred_sorted_indices] + + pred_pareto = [pred_sorted_points[0]] + for pt in pred_sorted_points[1:]: + if pt[1] < pred_pareto[-1][1]: + pred_pareto.append(pt) + pred_pareto = np.array(pred_pareto) + + ax2.step( + pred_pareto[:, 0], + pred_pareto[:, 1], + where="post", + linestyle="-", + color=color, + ) + + ax1.set_xlabel("Runtimes (seconds)") + ax1.set_ylabel("Relative Errors (%)") + ax1.set_title("Pareto Fronts for Different widen-range Values") + ax1.set_yscale("symlog", linthresh=1e-15) + ax1.set_ylim(bottom=0) + ax1.legend() + ax1.grid(True) + + if ax2 is not None: + ax2.set_xlabel("Cost Budget") + ax2.set_ylabel("Predicted Error") + ax2.set_title("Predicted 
Pareto Fronts") + ax2.set_yscale("symlog", linthresh=1e-15) + ax2.set_ylim(bottom=-1e-13) + ax2.grid(True) + + plot_filename = os.path.join(plots_dir, f"{prefix}ablation_widen_range_pareto_front.{output_format}") + plt.savefig(plot_filename, bbox_inches="tight", dpi=300) + plt.close() + print(f"Ablation plot saved to {plot_filename}") + + +def plot_ablation_results_cost_model( + tmp_dir, plots_dir, original_prefix, prefix, output_format="png", show_prediction=False +): + ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation-cost-model.pkl") + if not os.path.exists(ablation_data_file): + print(f"Ablation data file {ablation_data_file} does not exist. Cannot plot.") + sys.exit(1) + with open(ablation_data_file, "rb") as f: + all_data = pickle.load(f) + + if not all_data: + print("No data to plot.") + sys.exit(1) + + rcParams["font.size"] = 20 + rcParams["axes.titlesize"] = 24 + rcParams["axes.labelsize"] = 20 + rcParams["xtick.labelsize"] = 18 + rcParams["ytick.labelsize"] = 18 + rcParams["legend.fontsize"] = 18 + + if show_prediction and any("predicted_data" in d for d in all_data.values()): + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8)) + else: + fig, ax1 = plt.subplots(1, 1, figsize=(10, 8)) + ax2 = None + + # Extract original runtime/error from the first scenario + first_key = next(iter(all_data)) + original_runtime = all_data[first_key]["original_runtime"] + original_error = all_data[first_key]["original_error"] + + # Plot original program once + ax1.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program") + + colors = ["blue", "green"] + labels = ["Custom Cost Model", "TTI Cost Model"] + + for idx, key in enumerate(["with_cost_model", "without_cost_model"]): + data = all_data[key] + budgets = data["budgets"] + runtimes = data["runtimes"] + errors = data["errors"] + + data_points = list(zip(runtimes, errors)) + filtered_data = [(r, e) for r, e in data_points if r is not None and e is not None] + if not filtered_data: + print(f"No valid data to plot for {labels[idx]}.") + continue + runtimes_filtered, errors_filtered = zip(*filtered_data) + color = colors[idx] + ax1.scatter(runtimes_filtered, errors_filtered, label=labels[idx], color=color) + points = np.array(filtered_data) + sorted_indices = np.argsort(points[:, 0]) + sorted_points = points[sorted_indices] + + pareto_front = [sorted_points[0]] + for point in sorted_points[1:]: + if point[1] < pareto_front[-1][1]: + pareto_front.append(point) + + pareto_front = np.array(pareto_front) + + ax1.step( + pareto_front[:, 0], + pareto_front[:, 1], + where="post", + linestyle="-", + color=color, + ) + + if show_prediction and "predicted_data" in data and ax2 is not None: + p_costs = data["predicted_data"]["costs"] + p_errors = data["predicted_data"]["errors"] + if p_costs and p_errors: + ax2.scatter(p_costs, p_errors, label=f"{labels[idx]}", color=color) + pred_points = np.column_stack((p_costs, p_errors)) + pred_sorted_indices = np.argsort(pred_points[:, 0]) + pred_sorted_points = pred_points[pred_sorted_indices] + + pred_pareto = [pred_sorted_points[0]] + for pt in pred_sorted_points[1:]: + if pt[1] < pred_pareto[-1][1]: + pred_pareto.append(pt) + pred_pareto = np.array(pred_pareto) + + ax2.step( + pred_pareto[:, 0], + pred_pareto[:, 1], + where="post", + linestyle="-", + color=color, + ) + + ax1.set_xlabel("Runtimes (seconds)") + ax1.set_ylabel("Relative Errors (%)") + ax1.set_title("Pareto Fronts for Cost Model Ablation") + ax1.set_yscale("symlog", linthresh=1e-14) + 
ax1.set_ylim(bottom=-1e-14) + ax1.legend() + ax1.grid(True) + + if ax2 is not None: + ax2.set_xlabel("Cost Budget") + ax2.set_ylabel("Predicted Error") + ax2.set_title("Predicted Pareto Front") + ax2.set_yscale("symlog", linthresh=1e-14) + ax2.set_ylim(bottom=-1e-14) + ax2.legend() + ax2.grid(True) + + plot_filename = os.path.join(plots_dir, f"{prefix}ablation_cost_model_pareto_front.{output_format}") + plt.savefig(plot_filename, bbox_inches="tight", dpi=300) + plt.close() + print(f"Ablation plot saved to {plot_filename}") + + +def remove_mllvm_flag(flags_list, flag_prefix): + new_flags = [] + i = 0 + while i < len(flags_list): + if flags_list[i] == "-mllvm" and i + 1 < len(flags_list) and flags_list[i + 1].startswith(flag_prefix): + i += 2 + else: + new_flags.append(flags_list[i]) + i += 1 + return new_flags + + +def main(): + parser = argparse.ArgumentParser(description="Run the ablation study with widen-range parameter.") + parser.add_argument("--prefix", type=str, required=True, help="Prefix for intermediate files (e.g., rosa-ex23-)") + parser.add_argument("--clean", action="store_true", help="Clean up generated files") + parser.add_argument("--plot-only", action="store_true", help="Plot results from existing data") + parser.add_argument("--output-format", type=str, default="png", help="Output format for plots (e.g., png, pdf)") + parser.add_argument( + "--num-parallel", type=int, default=16, help="Number of parallel processes to use (default: 16)" + ) + parser.add_argument( + "--ablation-type", + type=str, + choices=["widen-range", "cost-model"], + default="widen-range", + help="Type of ablation study to perform (default: widen-range)", + ) + parser.add_argument("--show-prediction", action="store_true", help="Show predicted results in a second subplot") + + args = parser.parse_args() + + original_prefix = args.prefix + if not original_prefix.endswith("-"): + original_prefix += "-" + + tmp_dir = "tmp" + logs_dir = "logs" + plots_dir = "plots" + + os.makedirs(tmp_dir, exist_ok=True) + os.makedirs(logs_dir, exist_ok=True) + os.makedirs(plots_dir, exist_ok=True) + + example_txt_path = os.path.join(tmp_dir, f"{original_prefix}example.txt") + + if args.clean: + clean(tmp_dir, logs_dir, plots_dir) + sys.exit(0) + elif args.plot_only: + if args.ablation_type == "widen-range": + plot_ablation_results( + tmp_dir, + plots_dir, + original_prefix, + original_prefix, + args.output_format, + show_prediction=args.show_prediction, + ) + elif args.ablation_type == "cost-model": + plot_ablation_results_cost_model( + tmp_dir, + plots_dir, + original_prefix, + original_prefix, + args.output_format, + show_prediction=args.show_prediction, + ) + sys.exit(0) + else: + if not os.path.exists(example_txt_path): + generate_example_logged_cpp(tmp_dir, original_prefix, original_prefix) + compile_example_logged_exe(tmp_dir, original_prefix) + generate_example_txt(tmp_dir, original_prefix) + + if args.ablation_type == "widen-range": + # widen_ranges = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0, math.inf] + widen_ranges = [0, 1e-9, 1e-6, 0.001, 1.0, 1000.0, 1e6, 1e9] + all_data = {} + for X in widen_ranges: + print(f"=== Running ablation study with widen-range={X} ===") + remove_cache_dir() + FPOPTFLAGS_BASE = FPOPTFLAGS_BASE_TEMPLATE.copy() + for idx, flag in enumerate(FPOPTFLAGS_BASE): + if flag.startswith("--fpopt-log-path="): + FPOPTFLAGS_BASE[idx] = f"--fpopt-log-path={example_txt_path}" + FPOPTFLAGS_BASE.extend(["-mllvm", f"--fpopt-widen-range={X}"]) + + prefix_with_x = f"{original_prefix}abl-widen-range-{X}-" 
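+            # Benchmark this widen-range setting; the per-X file prefix keeps artifacts and results from different X values separate.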
+ + data = build_with_benchmark( + tmp_dir, + logs_dir, + plots_dir, + original_prefix, + prefix_with_x, + FPOPTFLAGS_BASE, + example_txt_path, + num_parallel=args.num_parallel, + ) + + all_data[X] = data + + ablation_data_file = os.path.join(tmp_dir, f"{original_prefix}ablation-widen-range.pkl") + with open(ablation_data_file, "wb") as f: + pickle.dump(all_data, f) + print(f"Ablation data saved to {ablation_data_file}") + + plot_ablation_results( + tmp_dir, + plots_dir, + original_prefix, + original_prefix, + args.output_format, + show_prediction=args.show_prediction, + ) + + if args.ablation_type == "cost-model": + print("=== Running cost-model ablation study ===") + remove_cache_dir() + FPOPTFLAGS_WITH_CM = FPOPTFLAGS_BASE_TEMPLATE.copy() + for idx, flag in enumerate(FPOPTFLAGS_WITH_CM): + if flag.startswith("--fpopt-log-path="): + FPOPTFLAGS_WITH_CM[idx] = f"--fpopt-log-path={example_txt_path}" + prefix_with_cm = f"{original_prefix}abl-with-cost-model-" + + data_with_cm = build_with_benchmark( + tmp_dir, + logs_dir, + plots_dir, + original_prefix, + prefix_with_cm, + FPOPTFLAGS_WITH_CM, + example_txt_path, + num_parallel=args.num_parallel, + ) + + remove_cache_dir() + FPOPTFLAGS_NO_CM = remove_mllvm_flag(FPOPTFLAGS_WITH_CM, "--fpopt-cost-model-path=") + prefix_without_cm = f"{original_prefix}abl-without-cost-model-" + + data_without_cm = build_with_benchmark( + tmp_dir, + logs_dir, + plots_dir, + original_prefix, + prefix_without_cm, + FPOPTFLAGS_NO_CM, + example_txt_path, + num_parallel=args.num_parallel, + ) + + all_data = { + "with_cost_model": data_with_cm, + "without_cost_model": data_without_cm, + } + + ablation_data_file = os.path.join(tmp_dir, f"{original_prefix}ablation-cost-model.pkl") + with open(ablation_data_file, "wb") as f: + pickle.dump(all_data, f) + print(f"Ablation data saved to {ablation_data_file}") + + plot_ablation_results_cost_model( + tmp_dir, + plots_dir, + original_prefix, + original_prefix, + args.output_format, + show_prediction=args.show_prediction, + ) + + +if __name__ == "__main__": + main() diff --git a/ablations/example.c b/ablations/example.c new file mode 100644 index 0000000..eb6a6f1 --- /dev/null +++ b/ablations/example.c @@ -0,0 +1,13 @@ +#include +#include +#define TRUE 1 +#define FALSE 0 + +// ## PRE v: 20, 20000 +// ## PRE T: -30, 50 +// ## PRE u: -100, 100 +__attribute__((noinline)) +double example(double u, double v, double T) { + double t1 = 331.4 + (0.6 * T); + return (-t1 * v) / ((t1 + u) * (t1 + u)); +} \ No newline at end of file diff --git a/ablations/fp-logger.cpp b/ablations/fp-logger.cpp new file mode 100644 index 0000000..545e31d --- /dev/null +++ b/ablations/fp-logger.cpp @@ -0,0 +1,165 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fp-logger.hpp" + +class ValueInfo { +public: + double minRes = std::numeric_limits::max(); + double maxRes = std::numeric_limits::lowest(); + std::vector minOperands; + std::vector maxOperands; + unsigned executions = 0; + double logSum = 0.0; + unsigned logCount = 0; + + void update(double res, const double *operands, unsigned numOperands) { + minRes = std::min(minRes, res); + maxRes = std::max(maxRes, res); + if (minOperands.empty()) { + minOperands.resize(numOperands, std::numeric_limits::max()); + maxOperands.resize(numOperands, std::numeric_limits::lowest()); + } + for (unsigned i = 0; i < numOperands; ++i) { + minOperands[i] = std::min(minOperands[i], operands[i]); + maxOperands[i] = std::max(maxOperands[i], operands[i]); + } + 
++executions; + + if (!std::isnan(res)) { + logSum += std::log1p(std::fabs(res)); + ++logCount; + } + } + + double getGeometricAverage() const { + if (logCount == 0) { + return 0.; + } + return std::expm1(logSum / logCount); + } +}; + +class ErrorInfo { +public: + double minErr = std::numeric_limits::max(); + double maxErr = std::numeric_limits::lowest(); + + void update(double err) { + minErr = std::min(minErr, err); + maxErr = std::max(maxErr, err); + } +}; + +class GradInfo { +public: + double logSum = 0.0; + unsigned count = 0; + + void update(double grad) { + if (!std::isnan(grad)) { + logSum += std::log1p(std::fabs(grad)); + ++count; + } + } + + double getGeometricAverage() const { + if (count == 0) { + return 0.; + } + return std::expm1(logSum / count); + } +}; + +class Logger { +private: + std::unordered_map valueInfo; + std::unordered_map errorInfo; + std::unordered_map gradInfo; + +public: + void updateValue(const std::string &id, double res, unsigned numOperands, + const double *operands) { + auto &info = valueInfo.emplace(id, ValueInfo()).first->second; + info.update(res, operands, numOperands); + } + + void updateError(const std::string &id, double err) { + auto &info = errorInfo.emplace(id, ErrorInfo()).first->second; + info.update(err); + } + + void updateGrad(const std::string &id, double grad) { + auto &info = gradInfo.emplace(id, GradInfo()).first->second; + info.update(grad); + } + + void print() const { + std::cout << std::scientific + << std::setprecision(std::numeric_limits::max_digits10); + + for (const auto &pair : valueInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Value:" << id << "\n"; + std::cout << "\tMinRes = " << info.minRes << "\n"; + std::cout << "\tMaxRes = " << info.maxRes << "\n"; + std::cout << "\tExecutions = " << info.executions << "\n"; + std::cout << "\tGeometric Average = " << info.getGeometricAverage() + << "\n"; + for (unsigned i = 0; i < info.minOperands.size(); ++i) { + std::cout << "\tOperand[" << i << "] = [" << info.minOperands[i] << ", " + << info.maxOperands[i] << "]\n"; + } + } + + for (const auto &pair : errorInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Error:" << id << "\n"; + std::cout << "\tMinErr = " << info.minErr << "\n"; + std::cout << "\tMaxErr = " << info.maxErr << "\n"; + } + + for (const auto &pair : gradInfo) { + const auto &id = pair.first; + const auto &info = pair.second; + std::cout << "Grad:" << id << "\n"; + std::cout << "\tGrad = " << info.getGeometricAverage() << "\n"; + } + } +}; + +Logger *logger = nullptr; + +void initializeLogger() { logger = new Logger(); } + +void destroyLogger() { + delete logger; + logger = nullptr; +} + +void printLogger() { logger->print(); } + +void enzymeLogError(const char *id, double err) { + assert(logger && "Logger is not initialized"); + logger->updateError(id, err); +} + +void enzymeLogGrad(const char *id, double grad) { + assert(logger && "Logger is not initialized"); + logger->updateGrad(id, grad); +} + +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands) { + assert(logger && "Logger is not initialized"); + logger->updateValue(id, res, numOperands, operands); +} \ No newline at end of file diff --git a/ablations/fp-logger.hpp b/ablations/fp-logger.hpp new file mode 100644 index 0000000..657aa94 --- /dev/null +++ b/ablations/fp-logger.hpp @@ -0,0 +1,8 @@ +void initializeLogger(); +void destroyLogger(); +void printLogger(); + +void enzymeLogError(const char *id, 
double err); +void enzymeLogGrad(const char *id, double grad); +void enzymeLogValue(const char *id, double res, unsigned numOperands, + double *operands); diff --git a/ablations/fpopt-baseline-generator.py b/ablations/fpopt-baseline-generator.py new file mode 100644 index 0000000..52c35d8 --- /dev/null +++ b/ablations/fpopt-baseline-generator.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +import os +import sys +import re +import numpy as np + +DEFAULT_NUM_SAMPLES = 10000 +DEFAULT_REGEX = "ex\\d+" + +np.random.seed(42) + + +def parse_bound(bound): + if "/" in bound: + numerator, denominator = map(float, bound.split("/")) + return numerator / denominator + return float(bound) + + +def parse_c_file(filepath, func_regex): + with open(filepath, "r") as file: + content = file.read() + + pattern = re.compile(rf"(?s)(// ## PRE(?:.*?\n)+?)\s*([\w\s\*]+?)\s+({func_regex})\s*\(([^)]*)\)") + + matches = pattern.findall(content) + + if not matches: + exit(f"No functions found with the regex: {func_regex}") + + functions = [] + + for comments, return_type, func_name, params in matches: + param_comments = re.findall(r"// ## PRE (\w+):\s*([-+.\d/]+),\s*([-+.\d/]+)", comments) + bounds = { + name: { + "min": parse_bound(min_val), + "max": parse_bound(max_val), + } + for name, min_val, max_val in param_comments + } + params = [param.strip() for param in params.split(",") if param.strip()] + functions.append((func_name, bounds, params, return_type.strip())) + + return functions + + +def create_baseline_functions(functions): + baseline_code = [] + for func_name, bounds, params, return_type in functions: + param_list = ", ".join(params) + baseline_func_name = f"baseline_{func_name}" + baseline_code.append(f"__attribute__((noinline))\n{return_type} {baseline_func_name}({param_list}) {{") + baseline_code.append(" return 42.0;") + baseline_code.append("}") + baseline_code.append("") + return "\n".join(baseline_code) + + +def create_baseline_driver_function(functions, num_samples_per_func): + driver_code = [ + "#include ", + "#include ", + "#include ", + "#include ", + ] + + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + + driver_code.append("") + driver_code.append("int main(int argc, char* argv[]) {") + driver_code.append(' std::string output_path = "";') + driver_code.append("") + driver_code.append(" for (int i = 1; i < argc; ++i) {") + driver_code.append(' if (std::strcmp(argv[i], "--output-path") == 0) {') + driver_code.append(" if (i + 1 < argc) {") + driver_code.append(" output_path = argv[i + 1];") + driver_code.append(" i++;") + driver_code.append(" } else {") + driver_code.append(' std::cerr << "Error: --output-path requires a path argument." 
<< std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + driver_code.append(" bool save_outputs = !output_path.empty();") + driver_code.append("") + driver_code.append(" std::mt19937 gen(42);") + driver_code.append("") + driver_code.append(" std::ofstream ofs;") + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.open(output_path);") + driver_code.append(" if (!ofs) {") + driver_code.append(' std::cerr << "Failed to open output file: " << output_path << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + + for func_name, bounds, params, return_type in functions: + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + try: + min_val = bounds[param_name]["min"] + max_val = bounds[param_name]["max"] + except KeyError: + exit( + f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds." + ) + dist_name = f"{func_name}_{param_name}_dist" + driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});") + driver_code.append("") + + driver_code.append(" double sum = 0.;") + driver_code.append("") + + driver_code.append(" auto start_time = std::chrono::high_resolution_clock::now();") + driver_code.append("") + + for func_name, bounds, params, return_type in functions: + baseline_func_name = f"baseline_{func_name}" + driver_code.append(f" for (int i = 0; i < {num_samples_per_func}; ++i) {{") + + call_params = [] + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + dist_name = f"{func_name}_{param_name}_dist" + param_value = f"{dist_name}(gen)" + call_params.append(param_value) + + driver_code.append(f" double res = {baseline_func_name}({', '.join(call_params)});") + driver_code.append(" sum += res;") + + driver_code.append(" if (save_outputs) {") + driver_code.append( + ' ofs << std::setprecision(std::numeric_limits::digits10 + 1) << res << "\\n";' + ) + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + + driver_code.append(' std::cout << "Sum: " << sum << std::endl;') + driver_code.append(" auto end_time = std::chrono::high_resolution_clock::now();") + driver_code.append(" std::chrono::duration elapsed = end_time - start_time;") + driver_code.append(' std::cout << "Total runtime: " << elapsed.count() << " seconds\\n";') + driver_code.append("") + + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.close();") + driver_code.append(" }") + driver_code.append("") + + driver_code.append(" return 0;") + driver_code.append("}") + return "\n".join(driver_code) + + +def main(): + if len(sys.argv) < 3: + exit( + "Usage: fpopt-baseline-generator.py [func_regex] [num_samples_per_func (default: 10000)]" + ) + + source_path = sys.argv[1] + dest_path = sys.argv[2] + func_regex = sys.argv[3] if len(sys.argv) > 3 else DEFAULT_REGEX + num_samples_per_func = int(sys.argv[4]) if len(sys.argv) > 4 else DEFAULT_NUM_SAMPLES + + functions = parse_c_file(source_path, func_regex) + baseline_functions_code = create_baseline_functions(functions) + driver_code = create_baseline_driver_function(functions, num_samples_per_func) + + with open(dest_path, "w") as 
new_file: + new_file.write(baseline_functions_code) + new_file.write("\n\n") + new_file.write(driver_code) + + print(f"Baseline code written to the new file: {dest_path}") + + +if __name__ == "__main__": + main() diff --git a/ablations/fpopt-golden-driver-generator.py b/ablations/fpopt-golden-driver-generator.py new file mode 100644 index 0000000..0e8b4ff --- /dev/null +++ b/ablations/fpopt-golden-driver-generator.py @@ -0,0 +1,197 @@ +import os +import sys +import re + +DEFAULT_NUM_SAMPLES = 100000 +DEFAULT_REGEX = "ex\\d+" + + +def parse_bound(bound): + if "/" in bound: + numerator, denominator = map(float, bound.split("/")) + return numerator / denominator + return float(bound) + + +def parse_c_file(filepath, func_regex): + with open(filepath, "r") as file: + content = file.read() + + pattern = re.compile(rf"(?s)(// ## PRE(?:.*?\n)+?)\s*([\w\s\*]+?)\s+({func_regex})\s*\(([^)]*)\)") + + matches = list(pattern.finditer(content)) + + if not matches: + exit(f"No functions found with the regex: {func_regex}") + + functions = [] + + for match in matches: + comments, return_type, func_name, params = match.groups() + param_comments = re.findall(r"// ## PRE (\w+):\s*([-+.\d/]+),\s*([-+.\d/]+)", comments) + bounds = { + name: { + "min": parse_bound(min_val), + "max": parse_bound(max_val), + } + for name, min_val, max_val in param_comments + } + params = [param.strip() for param in params.split(",") if param.strip()] + functions.append((func_name, bounds, params, return_type.strip())) + + return functions, content, matches + + +def create_driver_function(functions, num_samples_per_func): + driver_code = [ + "#undef double", + "", + "#include ", + "#include ", + "#include ", + "#include ", + ] + + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + + driver_code.append("") + driver_code.append("int main(int argc, char* argv[]) {") + driver_code.append(' std::string output_path = "";') + driver_code.append("") + driver_code.append(" for (int i = 1; i < argc; ++i) {") + driver_code.append(' if (std::strcmp(argv[i], "--output-path") == 0) {') + driver_code.append(" if (i + 1 < argc) {") + driver_code.append(" output_path = argv[i + 1];") + driver_code.append(" i++;") + driver_code.append(" } else {") + driver_code.append(' std::cerr << "Error: --output-path requires a path argument." << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + driver_code.append(" bool save_outputs = !output_path.empty();") + driver_code.append("") + driver_code.append(" std::mt19937 gen(42);") + driver_code.append("") + + driver_code.append(" std::ofstream ofs;") + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.open(output_path);") + driver_code.append(" if (!ofs) {") + driver_code.append(' std::cerr << "Failed to open output file: " << output_path << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + + for func_name, bounds, params, return_type in functions: + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + try: + min_val = bounds[param_name]["min"] + max_val = bounds[param_name]["max"] + except KeyError: + exit( + f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds." 
+ ) + dist_name = f"{func_name}_{param_name}_dist" + driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});") + driver_code.append("") + + driver_code.append(" double sum = 0.;") + driver_code.append("") + + driver_code.append(" auto start_time = std::chrono::high_resolution_clock::now();") + driver_code.append("") + + for func_name, bounds, params, return_type in functions: + driver_code.append(f" for (int i = 0; i < {num_samples_per_func}; ++i) {{") + + call_params = [] + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + dist_name = f"{func_name}_{param_name}_dist" + param_value = f"{dist_name}(gen)" + param_var_name = f"{param_name}_val" + driver_code.append(f" double {param_var_name} = {param_value};") + call_params.append(param_var_name) + + call_params_str = ", ".join(call_params) + + driver_code.append(f" double res = {func_name}({call_params_str});") + driver_code.append(" sum += res;") + + driver_code.append(" if (save_outputs) {") + driver_code.append( + ' ofs << std::setprecision(std::numeric_limits::digits10 + 1) << res << "\\n";' + ) + driver_code.append(" }") + + driver_code.append(" }") + driver_code.append("") + + driver_code.append(' std::cout << "Sum: " << sum << std::endl;') + driver_code.append(" auto end_time = std::chrono::high_resolution_clock::now();") + driver_code.append(" std::chrono::duration elapsed = end_time - start_time;") + driver_code.append(' std::cout << "Total runtime: " << elapsed.count() << " seconds\\n";') + driver_code.append("") + + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.close();") + driver_code.append(" }") + driver_code.append("") + + driver_code.append(" return 0;") + driver_code.append("}") + return "\n".join(driver_code) + + +def main(): + if len(sys.argv) < 4: + exit( + f"Usage: fpopt-golden-driver-generator.py [func_regex] [num_samples_per_func (default: {DEFAULT_NUM_SAMPLES})]" + ) + + source_path = sys.argv[1] + dest_path = sys.argv[2] + PREC = sys.argv[3] + func_regex = sys.argv[4] if len(sys.argv) > 4 else DEFAULT_REGEX + num_samples_per_func = int(sys.argv[5]) if len(sys.argv) > 5 else DEFAULT_NUM_SAMPLES + + functions, original_content, matches = parse_c_file(source_path, func_regex) + + driver_code = create_driver_function(functions, num_samples_per_func) + + with open(source_path, "r") as original_file: + original_content = original_file.read() + + mpfr_header = f'#include "mpfrcpp.hpp"\nconst unsigned int PREC = {PREC};\n#define double mpfrcpp\n\n' + + if matches: + first_match = matches[0] + insert_pos = first_match.start() + modified_content = original_content[:insert_pos] + mpfr_header + original_content[insert_pos:] + else: + exit("No matching functions found to insert mpfr header.") + + with open(dest_path, "w") as new_file: + new_file.write(modified_content) + new_file.write("\n\n" + driver_code) + + print(f"Driver program written to: {dest_path}") + + +if __name__ == "__main__": + main() diff --git a/ablations/fpopt-logged-driver-generator.py b/ablations/fpopt-logged-driver-generator.py new file mode 100644 index 0000000..16a0502 --- /dev/null +++ b/ablations/fpopt-logged-driver-generator.py @@ -0,0 +1,202 @@ +import os +import sys +import re +import random +import numpy as np + +DEFAULT_NUM_SAMPLES = 100000 +DEFAULT_REGEX = "ex\\d+" + +np.random.seed(42) + + +def parse_bound(bound): + if "/" in bound: + numerator, denominator 
= map(float, bound.split("/")) + return numerator / denominator + return float(bound) + + +def parse_c_file(filepath, func_regex): + with open(filepath, "r") as file: + content = file.read() + + pattern = re.compile(rf"(?s)(// ## PRE(?:.*?\n)+?)\s*([\w\s\*]+?)\s+({func_regex})\s*\(([^)]*)\)") + + matches = pattern.findall(content) + + if not matches: + exit(f"No functions found with the regex: {func_regex}") + + functions = [] + + for comments, return_type, func_name, params in matches: + param_comments = re.findall(r"// ## PRE (\w+):\s*([-+.\d/]+),\s*([-+.\d/]+)", comments) + bounds = { + name: { + "min": parse_bound(min_val), + "max": parse_bound(max_val), + } + for name, min_val, max_val in param_comments + } + params = [param.strip() for param in params.split(",") if param.strip()] + functions.append((func_name, bounds, params, return_type.strip())) + + return functions + + +def create_driver_function(functions, num_samples_per_func): + driver_code = [ + "#include ", + "#include ", + "#include ", + "#include ", + ] + + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + + driver_code.append("") + driver_code.append("int main(int argc, char* argv[]) {") + driver_code.append(' std::string output_path = "";') + driver_code.append("") + driver_code.append(" for (int i = 1; i < argc; ++i) {") + driver_code.append(' if (std::strcmp(argv[i], "--output-path") == 0) {') + driver_code.append(" if (i + 1 < argc) {") + driver_code.append(" output_path = argv[i + 1];") + driver_code.append(" i++;") + driver_code.append(" } else {") + driver_code.append(' std::cerr << "Error: --output-path requires a path argument." << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + driver_code.append(" bool save_outputs = !output_path.empty();") + driver_code.append("") + driver_code.append(" std::mt19937 gen(42);") + driver_code.append("") + driver_code.append(" std::ofstream ofs;") + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.open(output_path);") + driver_code.append(" if (!ofs) {") + driver_code.append(' std::cerr << "Failed to open output file: " << output_path << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + + driver_code.append(" initializeLogger();") + + for func_name, bounds, params, return_type in functions: + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + try: + min_val = bounds[param_name]["min"] + max_val = bounds[param_name]["max"] + except KeyError: + exit( + f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds." 
+ ) + dist_name = f"{func_name}_{param_name}_dist" + driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});") + driver_code.append("") + + driver_code.append(" double sum = 0.;") + driver_code.append("") + + driver_code.append(" auto start_time = std::chrono::high_resolution_clock::now();") + driver_code.append("") + + for func_name, bounds, params, return_type in functions: + driver_code.append(f" for (int i = 0; i < {num_samples_per_func}; ++i) {{") + + call_params = [] + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + dist_name = f"{func_name}_{param_name}_dist" + param_value = f"{dist_name}(gen)" + call_params.append(param_value) + + driver_code.append( + f" double res = __enzyme_autodiff<{return_type}>((void *) {func_name}, {', '.join(call_params)});" + ) + driver_code.append(" sum += res;") + driver_code.append(" if (save_outputs) {") + driver_code.append( + ' ofs << std::setprecision(std::numeric_limits::digits10 + 1) << res << "\\n";' + ) + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + + driver_code.append(' std::cout << "Sum: " << sum << std::endl;') + driver_code.append(" auto end_time = std::chrono::high_resolution_clock::now();") + driver_code.append(" std::chrono::duration elapsed = end_time - start_time;") + driver_code.append(' std::cout << "Total runtime: " << elapsed.count() << " seconds\\n";') + driver_code.append("") + + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.close();") + driver_code.append(" }") + driver_code.append("") + + driver_code.append(" printLogger();") + driver_code.append(" destroyLogger();") + driver_code.append(" return 0;") + driver_code.append("}") + return "\n".join(driver_code) + + +def main(): + if len(sys.argv) < 3: + exit( + f"Usage: fpopt-logged-driver-generator.py [func_regex] [num_samples_per_func (default: {DEFAULT_NUM_SAMPLES})]" + ) + + source_path = sys.argv[1] + dest_path = sys.argv[2] + func_regex = sys.argv[3] if len(sys.argv) > 3 else DEFAULT_REGEX + num_samples_per_func = int(sys.argv[4]) if len(sys.argv) > 4 else DEFAULT_NUM_SAMPLES + + functions = parse_c_file(source_path, func_regex) + driver_code = create_driver_function(functions, num_samples_per_func) + + with open(source_path, "r") as original_file: + original_content = original_file.read() + + code_to_insert = """#include "fp-logger.hpp" + +void thisIsNeverCalledAndJustForTheLinker() { + enzymeLogError("", 0.0); + enzymeLogGrad("", 0.0); + enzymeLogValue("", 0.0, 2, nullptr); +} + +int enzyme_dup; +int enzyme_dupnoneed; +int enzyme_out; +int enzyme_const; + +template +return_type __enzyme_autodiff(void *, T...);""" + + with open(dest_path, "w") as new_file: + new_file.write(original_content) + new_file.write("\n\n" + code_to_insert + "\n\n" + driver_code) + + print(f"Driver function appended to the new file: {dest_path}") + + +if __name__ == "__main__": + main() diff --git a/ablations/fpopt-original-driver-generator.py b/ablations/fpopt-original-driver-generator.py new file mode 100644 index 0000000..0f1ed2d --- /dev/null +++ b/ablations/fpopt-original-driver-generator.py @@ -0,0 +1,190 @@ +import os +import sys +import re +import random +import numpy as np + +DEFAULT_NUM_SAMPLES = 100000 +DEFAULT_REGEX = "ex\\d+" + +np.random.seed(42) + + +def parse_bound(bound): + if "/" in bound: + numerator, denominator = map(float, bound.split("/")) + return 
numerator / denominator + return float(bound) + + +def parse_c_file(filepath, func_regex): + with open(filepath, "r") as file: + content = file.read() + + pattern = re.compile(rf"(?s)(// ## PRE(?:.*?\n)+?)\s*([\w\s\*]+?)\s+({func_regex})\s*\(([^)]*)\)") + + matches = pattern.findall(content) + + if not matches: + exit(f"No functions found with the regex: {func_regex}") + + functions = [] + + for comments, return_type, func_name, params in matches: + param_comments = re.findall(r"// ## PRE (\w+):\s*([-+.\d/]+),\s*([-+.\d/]+)", comments) + bounds = { + name: { + "min": parse_bound(min_val), + "max": parse_bound(max_val), + } + for name, min_val, max_val in param_comments + } + params = [param.strip() for param in params.split(",") if param.strip()] + functions.append((func_name, bounds, params, return_type.strip())) + + return functions + + +def create_driver_function(functions, num_samples_per_func): + driver_code = [ + "#include ", + "#include ", + "#include ", + "#include ", + ] + + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + driver_code.append("#include ") + + driver_code.append("") + driver_code.append("int main(int argc, char* argv[]) {") + driver_code.append(' std::string output_path = "";') + driver_code.append("") + driver_code.append(" for (int i = 1; i < argc; ++i) {") + driver_code.append(' if (std::strcmp(argv[i], "--output-path") == 0) {') + driver_code.append(" if (i + 1 < argc) {") + driver_code.append(" output_path = argv[i + 1];") + driver_code.append(" i++;") + driver_code.append(" } else {") + driver_code.append(' std::cerr << "Error: --output-path requires a path argument." << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + driver_code.append(" bool save_outputs = !output_path.empty();") + driver_code.append("") + driver_code.append(" std::mt19937 gen(42);") + driver_code.append("") + driver_code.append(" std::ofstream ofs;") + driver_code.append(" if (save_outputs) {") + driver_code.append(" ofs.open(output_path);") + driver_code.append(" if (!ofs) {") + driver_code.append(' std::cerr << "Failed to open output file: " << output_path << std::endl;') + driver_code.append(" return 1;") + driver_code.append(" }") + driver_code.append(" }") + driver_code.append("") + + for func_name, bounds, params, return_type in functions: + for param in params: + param_tokens = param.strip().split() + if len(param_tokens) >= 2: + param_name = param_tokens[-1] + else: + exit(f"Cannot parse parameter: {param}") + try: + min_val = bounds[param_name]["min"] + max_val = bounds[param_name]["max"] + except KeyError: + exit( + f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds." 
+                )
+            dist_name = f"{func_name}_{param_name}_dist"
+            driver_code.append(f"  std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});")
+    driver_code.append("")
+
+    driver_code.append("  double sum = 0.;")
+    driver_code.append("")
+
+    driver_code.append("  auto start_time = std::chrono::high_resolution_clock::now();")
+    driver_code.append("")
+
+    for func_name, bounds, params, return_type in functions:
+        driver_code.append(f"  for (int i = 0; i < {num_samples_per_func}; ++i) {{")
+
+        call_params = []
+        for param in params:
+            param_tokens = param.strip().split()
+            if len(param_tokens) >= 2:
+                param_name = param_tokens[-1]
+            else:
+                exit(f"Cannot parse parameter: {param}")
+            dist_name = f"{func_name}_{param_name}_dist"
+            param_value = f"{dist_name}(gen)"
+            param_var_name = f"{param_name}_val"
+            driver_code.append(f"    double {param_var_name} = {param_value};")
+            call_params.append(param_var_name)
+
+        call_params_str = ", ".join(call_params)
+
+        driver_code.append(f"    double res = {func_name}({call_params_str});")
+        driver_code.append("    sum += res;")
+
+        driver_code.append("    if (save_outputs) {")
+        driver_code.append(
+            '      ofs << std::setprecision(std::numeric_limits<double>::digits10 + 1) << res << "\\n";'
+        )
+        driver_code.append("    }")
+
+        driver_code.append("  }")
+        driver_code.append("")
+
+    driver_code.append('  std::cout << "Sum: " << sum << std::endl;')
+    driver_code.append("  auto end_time = std::chrono::high_resolution_clock::now();")
+    driver_code.append("  std::chrono::duration<double> elapsed = end_time - start_time;")
+    driver_code.append('  std::cout << "Total runtime: " << elapsed.count() << " seconds\\n";')
+    driver_code.append("")
+
+    driver_code.append("  if (save_outputs) {")
+    driver_code.append("    ofs.close();")
+    driver_code.append("  }")
+    driver_code.append("")
+
+    driver_code.append("  return 0;")
+    driver_code.append("}")
+    return "\n".join(driver_code)
+
+
+def main():
+    if len(sys.argv) < 3:
+        exit(
+            f"Usage: fpopt-original-driver-generator.py <source_path> <dest_path> [func_regex] [num_samples_per_func (default: {DEFAULT_NUM_SAMPLES})]"
+        )
+
+    source_path = sys.argv[1]
+    dest_path = sys.argv[2]
+    func_regex = sys.argv[3] if len(sys.argv) > 3 else DEFAULT_REGEX
+    num_samples_per_func = int(sys.argv[4]) if len(sys.argv) > 4 else DEFAULT_NUM_SAMPLES
+
+    if len(sys.argv) <= 3:
+        print(f"WARNING: No regex provided for target function names. Using default regex: {DEFAULT_REGEX}")
+
+    functions = parse_c_file(source_path, func_regex)
+
+    driver_code = create_driver_function(functions, num_samples_per_func)
+
+    with open(source_path, "r") as original_file:
+        original_content = original_file.read()
+
+    with open(dest_path, "w") as new_file:
+        new_file.write(original_content)
+        new_file.write("\n\n" + driver_code)
+
+    print(f"Driver program written to: {dest_path}")
+
+
+if __name__ == "__main__":
+    main()
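For the hypothetical ex0 above, create_driver_function emits a driver along the following lines, which main() then appends to a copy of the original source (a condensed sketch: the --output-path parsing and output-file handling shown in the generator are elided here):

int main(int argc, char* argv[]) {
  std::mt19937 gen(42);  // fixed seed so every build samples the same inputs
  std::uniform_real_distribution<double> ex0_x_dist(0.1, 100.0);
  std::uniform_real_distribution<double> ex0_y_dist(-0.5, 0.5);

  double sum = 0.;
  auto start_time = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < 100000; ++i) {
    double x_val = ex0_x_dist(gen);
    double y_val = ex0_y_dist(gen);
    double res = ex0(x_val, y_val);  // direct call; the logged variant calls __enzyme_autodiff instead
    sum += res;
  }
  std::cout << "Sum: " << sum << std::endl;
  auto end_time = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed = end_time - start_time;
  std::cout << "Total runtime: " << elapsed.count() << " seconds\n";
  return 0;
}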
diff --git a/ablations/mpfrcpp.hpp b/ablations/mpfrcpp.hpp
new file mode 100644
index 0000000..7d1686f
--- /dev/null
+++ b/ablations/mpfrcpp.hpp
@@ -0,0 +1,261 @@
+// Adapted from https://github.com/jhueckelheim/force/blob/master/include/mpfrcpp_tpl.h
+#ifndef mpfrcpp_h
+#define mpfrcpp_h
+
+#include <mpfr.h>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+
+template <int MPFRPREC> class mpfrcpp {
+public:
+  mpfr_t value;
+
+  mpfrcpp() { mpfr_init2(value, MPFRPREC); }
+  mpfrcpp(const float v) {
+    mpfr_init2(value, MPFRPREC);
+    mpfr_set_flt(value, v, MPFR_RNDN);
+  }
+  mpfrcpp(const double v) {
+    mpfr_init2(value, MPFRPREC);
+    mpfr_set_d(value, v, MPFR_RNDN);
+  }
+  mpfrcpp(const long double v) {
+    mpfr_init2(value, MPFRPREC);
+    mpfr_set_ld(value, v, MPFR_RNDN);
+  }
+  mpfrcpp(const mpfrcpp &v) {
+    mpfr_init2(value, MPFRPREC);
+    mpfr_set(value, v.value, MPFR_RNDN);
+  }
+  mpfrcpp(const mpfr_t &v) {
+    mpfr_init2(value, MPFRPREC);
+    mpfr_set(value, v, MPFR_RNDN);
+  }
+  mpfrcpp(const char *v) {
+    mpfr_init2(value, MPFRPREC);
+    mpfr_set_str(value, v, 10, MPFR_RNDN);
+  }
+  ~mpfrcpp() { mpfr_clear(value); }
+
+  mpfrcpp &operator=(const mpfrcpp &g1) {
+    mpfr_set(value, g1.value, MPFR_RNDN);
+    return *this;
+  }
+  mpfrcpp &operator=(const double &g1) {
+    mpfr_set_d(value, g1, MPFR_RNDN);
+    return *this;
+  }
+
+  mpfrcpp &operator+=(const double &g1) {
+    mpfr_add_d(value, value, g1, MPFR_RNDN);
+    return *this;
+  }
+  mpfrcpp &operator-=(const double &g1) {
+    mpfr_sub_d(value, value, g1, MPFR_RNDN);
+    return *this;
+  }
+  mpfrcpp &operator*=(const double &g1) {
+    mpfr_mul_d(value, value, g1, MPFR_RNDN);
+    return *this;
+  }
+  mpfrcpp &operator/=(const double &g1) {
+    mpfr_div_d(value, value, g1, MPFR_RNDN);
+    return *this;
+  }
+
+  operator double() const { return mpfr_get_d(value, MPFR_RNDN); }
+};
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator+(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_add_d(res.value, g1.value, g2, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator-(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_sub_d(res.value, g1.value, g2, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator*(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_mul_d(res.value, g1.value, g2, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator/(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_div_d(res.value, g1.value, g2, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator+(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_add_d(res.value, g2.value, g1, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator-(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_d_sub(res.value, g1, g2.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator*(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_mul_d(res.value, g2.value, g1, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator/(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_d_div(res.value, g1, g2.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator+(const mpfrcpp<MPFRPREC> &g1,
+                            const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_add(res.value, g1.value, g2.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator-(const mpfrcpp<MPFRPREC> &g1,
+                            const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_sub(res.value, g1.value, g2.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator*(const mpfrcpp<MPFRPREC> &g1,
+                            const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_mul(res.value, g1.value, g2.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> operator/(const mpfrcpp<MPFRPREC> &g1,
+                            const mpfrcpp<MPFRPREC> &g2) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_div(res.value, g1.value, g2.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+bool operator>=(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  return mpfr_cmp_d(g1.value, g2) >= 0;
+}
+
+template <int MPFRPREC>
+bool operator>(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  return mpfr_cmp_d(g1.value, g2) > 0;
+}
+
+template <int MPFRPREC>
+bool operator<=(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  return mpfr_cmp_d(g1.value, g2) <= 0;
+}
+
+template <int MPFRPREC>
+bool operator<(const mpfrcpp<MPFRPREC> &g1, const double &g2) {
+  return mpfr_cmp_d(g1.value, g2) < 0;
+}
+
+template <int MPFRPREC>
+bool operator>=(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_cmp_d(g2.value, g1) <= 0;
+}
+
+template <int MPFRPREC>
+bool operator>(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_cmp_d(g2.value, g1) < 0;
+}
+
+template <int MPFRPREC>
+bool operator<=(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_cmp_d(g2.value, g1) >= 0;
+}
+
+template <int MPFRPREC>
+bool operator<(const double &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_cmp_d(g2.value, g1) > 0;
+}
+
+template <int MPFRPREC>
+int operator>=(const mpfrcpp<MPFRPREC> &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_greaterequal_p(g1.value, g2.value);
+}
+
+template <int MPFRPREC>
+int operator>(const mpfrcpp<MPFRPREC> &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_greater_p(g1.value, g2.value);
+}
+
+template <int MPFRPREC>
+int operator<(const mpfrcpp<MPFRPREC> &g1, const mpfrcpp<MPFRPREC> &g2) {
+  return mpfr_greater_p(g2.value, g1.value);
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> fabs(const mpfrcpp<MPFRPREC> &g1) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_abs(res.value, g1.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> pow(const mpfrcpp<MPFRPREC> &g1, double expd) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_t exp_mpfr;
+  mpfr_init2(exp_mpfr, MPFRPREC);
+  mpfr_set_d(exp_mpfr, expd, MPFR_RNDN);
+  mpfr_pow(res.value, g1.value, exp_mpfr, MPFR_RNDN);
+  mpfr_clear(exp_mpfr);
+  return res;
+}
+
+template <int MPFRPREC>
+mpfrcpp<MPFRPREC> sqrt(const mpfrcpp<MPFRPREC> &g1) {
+  mpfrcpp<MPFRPREC> res;
+  mpfr_sqrt(res.value, g1.value, MPFR_RNDN);
+  return res;
+}
+
+template <int MPFRPREC>
+std::ostream &operator<<(std::ostream &ost, const mpfrcpp<MPFRPREC> &ad) {
+  char *abc = NULL;
+  mpfr_exp_t i;
+  if (ad >= mpfrcpp<MPFRPREC>(0.0)) {
+    abc = mpfr_get_str(NULL, &i, 10, 0, ad.value, MPFR_RNDN);
+    ost << "0." << abc << "e" << i;
+  } else {
+    // ad < 0 here, so fabs(ad) yields the same digits as -ad (no unary minus is defined)
+    abc = mpfr_get_str(NULL, &i, 10, 0, fabs(ad).value, MPFR_RNDN);
+    ost << "-0." << abc << "e" << i;
+  }
+  mpfr_free_str(abc);
+  return ost;
+}
+
+template <int TOPREC, int FROMPREC>
+mpfrcpp<TOPREC> convert(const mpfrcpp<FROMPREC> &from) {
+  mpfrcpp<TOPREC> to;
+  mpfr_set(to.value, from.value, MPFR_RNDN);
+  return to;
+}
+
+#endif
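mpfrcpp is a thin RAII wrapper that carries its precision in the template parameter and mixes freely with double operands. A minimal usage sketch (assuming MPFR is installed and linked with -lmpfr; the 256-bit precision is just an illustrative choice):

#include "mpfrcpp.hpp"
#include <iostream>

int main() {
  mpfrcpp<256> a(1.0), b(3.0);   // values with 256-bit significands
  mpfrcpp<256> c = a / b + 2.0;  // mixed mpfrcpp/double arithmetic, round-to-nearest
  std::cout << c << "\n";        // prints sign, mantissa digits, and decimal exponent
  double approx = c;             // implicit narrowing back to double
  return (approx > 2.0) ? 0 : 1;
}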
diff --git a/ablations/run-all.py b/ablations/run-all.py
new file mode 100644
index 0000000..99ed197
--- /dev/null
+++ b/ablations/run-all.py
@@ -0,0 +1,192 @@
+import os
+import re
+import subprocess
+import glob
+import multiprocessing
+import argparse
+
+
+def extract_functions_from_c_file(content, func_regex="ex\\d+"):
+    functions = []
+    lines = content.splitlines()
+    i = 0
+    while i < len(lines):
+        line = lines[i]
+        func_def_pattern = re.compile(rf"^\s*(.*?)\s+({func_regex})\s*\((.*?)\)\s*\{{\s*$")
+        match = func_def_pattern.match(line)
+        if match:
+            return_type = match.group(1).strip()
+            func_name = match.group(2)
+            params = match.group(3).strip()
+            comments = []
+            j = i - 1
+            while j >= 0:
+                prev_line = lines[j]
+                if prev_line.strip().startswith("//"):
+                    comments.insert(0, prev_line)
+                    j -= 1
+                elif prev_line.strip() == "":
+                    j -= 1
+                else:
+                    break
+            func_body_lines = [line]
+            brace_level = line.count("{") - line.count("}")
+            i += 1
+            while i < len(lines) and brace_level > 0:
+                func_line = lines[i]
+                func_body_lines.append(func_line)
+                brace_level += func_line.count("{")
+                brace_level -= func_line.count("}")
+                i += 1
+            func_body = "\n".join(func_body_lines)
+            comments_str = "\n".join(comments)
+            functions.append(
+                {
+                    "comments": comments_str,
+                    "return_type": return_type,
+                    "func_name": func_name,
+                    "params": params,
+                    "func_body": func_body,
+                }
+            )
+        else:
+            i += 1
+    return functions
+
+
+def process_function_task(func, base_name, plot_only, ablation_type):
+    func_name = func["func_name"]
+    return_type = func["return_type"]
+    params = func["params"]
+    comments = func["comments"]
+    print(f"Processing function: {func_name}")
+
+    prefix = f"{base_name}-{func_name}-"
+    example_c_filename = f"{prefix}example.c"
+    example_c_filepath = os.path.join("tmp", example_c_filename)
+
+    func_body_lines = func["func_body"].split("\n")
+    func_signature_line = f"__attribute__((noinline))\n{return_type} example({params}) {{"
+    func_body_lines[0] = func_signature_line
+    func_code = comments + "\n" + "\n".join(func_body_lines)
+    includes = "#include <math.h>\n#include <stdint.h>\n#define TRUE 1\n#define FALSE 0\n"
+    example_c_content = includes + "\n" + func_code
+    os.makedirs("tmp", exist_ok=True)
+    with open(example_c_filepath, "w") as f:
+        f.write(example_c_content)
+
+    base_command = [
+        "python3",
+        "ablation.py",
+        "--prefix",
+        prefix,
+        "--ablation-type",
+        ablation_type,
+    ]
+    if plot_only:
+        base_command.append("--plot-only")
+
+    try:
+        print("Running command:", " ".join(base_command))
+        subprocess.check_call(base_command)
+    except subprocess.CalledProcessError:
+        print(
+            f"Error running ablation.py for function {func_name} in base {base_name}. Retrying with --disable-preopt."
+ ) + retry_command = [ + "python3", + "ablation.py", + "--prefix", + prefix, + "--ablation-type", + ablation_type, + "--disable-preopt", + ] + if plot_only: + retry_command.append("--plot-only") + try: + print("Running retry command:", " ".join(retry_command)) + subprocess.check_call(retry_command) + except subprocess.CalledProcessError: + print(f"Error running ablation.py with --disable-preopt for function {func_name} in base {base_name}") + + +def main(): + parser = argparse.ArgumentParser(description="Process benchmarks and run ablation studies on all functions.") + parser.add_argument("-j", type=int, help="Number of parallel tasks", default=1) + parser.add_argument("--regen", action="store_true", help="Force regeneration of .c files") + parser.add_argument("--plot-only", action="store_true", help="Only plot the results") + parser.add_argument( + "--ablation-type", + type=str, + choices=["widen-range", "cost-model"], + default="widen-range", + help="Type of ablation study to run on all benchmarks (default: widen-range)", + ) + args = parser.parse_args() + + num_parallel_tasks = args.j + force_regen = args.regen + plot_only = args.plot_only + ablation_type = args.ablation_type + + source_dir = "../benchmarks" + exported_dir = "exported" + racket_script = "../export.rkt" + fpcore_files = glob.glob(os.path.join(source_dir, "*.fpcore")) + + if not fpcore_files: + print("No .fpcore files found in the benchmarks directory.") + return + + os.makedirs("tmp", exist_ok=True) + os.makedirs("exported", exist_ok=True) + + tasks = [] + + for fpcore_file in fpcore_files: + filename = os.path.basename(fpcore_file) + base_name = os.path.splitext(filename)[0] + c_filename = f"{base_name}.fpcore.c" + c_filepath = os.path.join(exported_dir, c_filename) + + if not force_regen and os.path.exists(c_filepath): + print(f"{c_filename} already exists. Skipping generation.") + else: + print(f"Generating {c_filename} using Racket script...") + try: + cmd = ["racket", racket_script, fpcore_file, c_filepath] + print("Running command:", " ".join(cmd)) + subprocess.check_call(cmd) + except subprocess.CalledProcessError: + print(f"Error running export.rkt on {filename}") + continue + + print(f"Processing generated .c file: {c_filename}") + with open(c_filepath, "r") as f: + content = f.read() + functions = extract_functions_from_c_file(content) + if not functions: + print(f"No ex functions found in {c_filename}") + continue + for func in functions: + func_name = func["func_name"] + print(f"Found function: {func_name}") + task = (func, base_name, plot_only, ablation_type) + tasks.append(task) + + if tasks: + if num_parallel_tasks == 1: + for task in tasks: + process_function_task(*task) + else: + with multiprocessing.Pool(num_parallel_tasks) as pool: + pool.starmap(process_function_task, tasks) + else: + print("No functions to process.") + + print("Processing completed.") + + +if __name__ == "__main__": + main()
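For reference, given the hypothetical annotated ex0 shown earlier, process_function_task would write roughly the following to tmp/<benchmark>-ex0-example.c before invoking ablation.py (a sketch: the exact include set comes from the includes string above, and the original PRE comments and function body are preserved verbatim, with only the signature renamed to example and marked noinline):

#include <math.h>
#include <stdint.h>
#define TRUE 1
#define FALSE 0

// ## PRE x: 0.1, 100.0
// ## PRE y: -1/2, 1/2
__attribute__((noinline))
double example(double x, double y) {
  return sqrt(x * x + y * y);
}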