Skip to content

Commit

Permalink
cost model ablation
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Dec 5, 2024
1 parent e76948b commit 3fea89e
Showing 1 changed file with 162 additions and 28 deletions.
190 changes: 162 additions & 28 deletions experiments/ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
from statistics import mean
import pickle
from tqdm import tqdm, trange
from tqdm import trange
from matplotlib import rcParams
from concurrent.futures import ProcessPoolExecutor, as_completed

Expand Down Expand Up @@ -71,7 +71,6 @@

SRC = "example.c"
LOGGER = "fp-logger.cpp"
EXE = ["example.exe", "example-logged.exe", "example-fpopt.exe"]
NUM_RUNS = 10
DRIVER_NUM_SAMPLES = 10000000
LOG_NUM_SAMPLES = 10000
Expand Down Expand Up @@ -456,7 +455,6 @@ def benchmark(tmp_dir, logs_dir, original_prefix, prefix, plots_dir, fpoptflags,
data_tuples_sorted = sorted(data_tuples, key=lambda x: x[0])
sorted_budgets, sorted_optimized_binaries = zip(*data_tuples_sorted) if data_tuples_sorted else ([], [])

# Measure runtimes serially based on sorted budgets
sorted_runtimes = []
for cost, output_binary in zip(sorted_budgets, sorted_optimized_binaries):
avg_runtime = measure_runtime(tmp_dir, prefix, output_binary, NUM_RUNS)
Expand Down Expand Up @@ -507,7 +505,7 @@ def remove_cache_dir():


def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_format="png"):
ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation.pkl")
ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation-widen-range.pkl")
if not os.path.exists(ablation_data_file):
print(f"Ablation data file {ablation_data_file} does not exist. Cannot plot.")
sys.exit(1)
Expand All @@ -533,7 +531,6 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
original_runtime = data["original_runtime"]
original_error = data["original_error"]

# Filter out entries where runtimes or errors are None
data_points = list(zip(runtimes, errors))
filtered_data = [(r, e) for r, e in data_points if r is not None and e is not None]
if not filtered_data:
Expand All @@ -542,7 +539,6 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
runtimes_filtered, errors_filtered = zip(*filtered_data)
color = next(color_iter)
plt.scatter(runtimes_filtered, errors_filtered, label=f"widen-range={X}", color=color)
# Calculate Pareto Front
points = np.array(filtered_data)
sorted_indices = np.argsort(points[:, 0])
sorted_points = points[sorted_indices]
Expand All @@ -562,7 +558,6 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
color=color,
)

# Plot the original program
plt.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program")

plt.xlabel("Runtimes (seconds)")
Expand All @@ -572,12 +567,94 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
plt.legend()
plt.grid(True)

plot_filename = os.path.join(plots_dir, f"{prefix}ablation_pareto_front.{output_format}")
plot_filename = os.path.join(plots_dir, f"{prefix}ablation_widen_range_pareto_front.{output_format}")
plt.savefig(plot_filename, bbox_inches="tight", dpi=300)
plt.close()
print(f"Ablation plot saved to {plot_filename}")


def plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, prefix, output_format="png"):
ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation-cost-model.pkl")
if not os.path.exists(ablation_data_file):
print(f"Ablation data file {ablation_data_file} does not exist. Cannot plot.")
sys.exit(1)
with open(ablation_data_file, "rb") as f:
all_data = pickle.load(f)

rcParams["font.size"] = 20
rcParams["axes.titlesize"] = 24
rcParams["axes.labelsize"] = 20
rcParams["xtick.labelsize"] = 18
rcParams["ytick.labelsize"] = 18
rcParams["legend.fontsize"] = 18

plt.figure(figsize=(10, 8))

colors = ["blue", "green"]
labels = ["With Cost Model", "Without Cost Model"]

for idx, key in enumerate(["with_cost_model", "without_cost_model"]):
data = all_data[key]
budgets = data["budgets"]
runtimes = data["runtimes"]
errors = data["errors"]
original_runtime = data["original_runtime"]
original_error = data["original_error"]

data_points = list(zip(runtimes, errors))
filtered_data = [(r, e) for r, e in data_points if r is not None and e is not None]
if not filtered_data:
print(f"No valid data to plot for {labels[idx]}.")
continue
runtimes_filtered, errors_filtered = zip(*filtered_data)
color = colors[idx]
plt.scatter(runtimes_filtered, errors_filtered, label=labels[idx], color=color)
points = np.array(filtered_data)
sorted_indices = np.argsort(points[:, 0])
sorted_points = points[sorted_indices]

pareto_front = [sorted_points[0]]
for point in sorted_points[1:]:
if point[1] < pareto_front[-1][1]:
pareto_front.append(point)

pareto_front = np.array(pareto_front)

plt.step(
pareto_front[:, 0],
pareto_front[:, 1],
where="post",
linestyle="-",
color=color,
)

plt.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program")

plt.xlabel("Runtimes (seconds)")
plt.ylabel("Relative Errors (%)")
plt.title("Pareto Fronts for Cost Model Ablation")
plt.yscale("log")
plt.legend()
plt.grid(True)

plot_filename = os.path.join(plots_dir, f"{prefix}ablation_cost_model_pareto_front.{output_format}")
plt.savefig(plot_filename, bbox_inches="tight", dpi=300)
plt.close()
print(f"Ablation plot saved to {plot_filename}")


def remove_mllvm_flag(flags_list, flag_prefix):
new_flags = []
i = 0
while i < len(flags_list):
if flags_list[i] == "-mllvm" and i + 1 < len(flags_list) and flags_list[i + 1].startswith(flag_prefix):
i += 2
else:
new_flags.append(flags_list[i])
i += 1
return new_flags


def main():
parser = argparse.ArgumentParser(description="Run the ablation study with widen-range parameter.")
parser.add_argument("--prefix", type=str, required=True, help="Prefix for intermediate files (e.g., rosa-ex23-)")
Expand All @@ -587,6 +664,13 @@ def main():
parser.add_argument(
"--num-parallel", type=int, default=16, help="Number of parallel processes to use (default: 16)"
)
parser.add_argument(
"--ablation-type",
type=str,
choices=["widen-range", "cost-model"],
default="widen-range",
help="Type of ablation study to perform (default: widen-range)",
)
args = parser.parse_args()

original_prefix = args.prefix
Expand All @@ -607,47 +691,97 @@ def main():
clean(tmp_dir, logs_dir, plots_dir)
sys.exit(0)
elif args.plot_only:
plot_ablation_results(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
if args.ablation_type == "widen-range":
plot_ablation_results(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
elif args.ablation_type == "cost-model":
plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
sys.exit(0)
else:
# Generate example.txt only once with the original prefix
if not os.path.exists(example_txt_path):
generate_example_logged_cpp(tmp_dir, original_prefix, original_prefix)
compile_example_logged_exe(tmp_dir, original_prefix)
generate_example_txt(tmp_dir, original_prefix)

widen_ranges = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
all_data = {}
for X in widen_ranges:
print(f"=== Running ablation study with widen-range={X} ===")
if args.ablation_type == "widen-range":
widen_ranges = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
all_data = {}
for X in widen_ranges:
print(f"=== Running ablation study with widen-range={X} ===")
remove_cache_dir()
FPOPTFLAGS_BASE = FPOPTFLAGS_BASE_TEMPLATE.copy()
for idx, flag in enumerate(FPOPTFLAGS_BASE):
if flag.startswith("--fpopt-log-path="):
FPOPTFLAGS_BASE[idx] = f"--fpopt-log-path={example_txt_path}"
FPOPTFLAGS_BASE.extend(["-mllvm", f"--fpopt-widen-range={X}"])

prefix_with_x = f"{original_prefix}abl-widen-range-{X}-"

data = build_with_benchmark(
tmp_dir,
logs_dir,
plots_dir,
original_prefix,
prefix_with_x,
FPOPTFLAGS_BASE,
example_txt_path,
num_parallel=args.num_parallel,
)

all_data[X] = data

ablation_data_file = os.path.join(tmp_dir, f"{original_prefix}ablation-widen-range.pkl")
with open(ablation_data_file, "wb") as f:
pickle.dump(all_data, f)
print(f"Ablation data saved to {ablation_data_file}")

plot_ablation_results(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)

if args.ablation_type == "cost-model":
print("=== Running cost-model ablation study ===")
remove_cache_dir()
FPOPTFLAGS_BASE = FPOPTFLAGS_BASE_TEMPLATE.copy()
for idx, flag in enumerate(FPOPTFLAGS_BASE):
FPOPTFLAGS_WITH_CM = FPOPTFLAGS_BASE_TEMPLATE.copy()
for idx, flag in enumerate(FPOPTFLAGS_WITH_CM):
if flag.startswith("--fpopt-log-path="):
FPOPTFLAGS_BASE[idx] = f"--fpopt-log-path={example_txt_path}"
FPOPTFLAGS_BASE.extend(["-mllvm", f"--fpopt-widen-range={X}"])
FPOPTFLAGS_WITH_CM[idx] = f"--fpopt-log-path={example_txt_path}"
prefix_with_cm = f"{original_prefix}abl-with-cost-model-"

prefix_with_x = f"{original_prefix}abl-widen-range-{X}-"
data_with_cm = build_with_benchmark(
tmp_dir,
logs_dir,
plots_dir,
original_prefix,
prefix_with_cm,
FPOPTFLAGS_WITH_CM,
example_txt_path,
num_parallel=args.num_parallel,
)

remove_cache_dir()
FPOPTFLAGS_NO_CM = remove_mllvm_flag(FPOPTFLAGS_WITH_CM, "--fpopt-cost-model-path=")
prefix_without_cm = f"{original_prefix}abl-without-cost-model-"

data = build_with_benchmark(
data_without_cm = build_with_benchmark(
tmp_dir,
logs_dir,
plots_dir,
original_prefix,
prefix_with_x,
FPOPTFLAGS_BASE,
prefix_without_cm,
FPOPTFLAGS_NO_CM,
example_txt_path,
num_parallel=args.num_parallel,
)

all_data[X] = data
all_data = {
"with_cost_model": data_with_cm,
"without_cost_model": data_without_cm,
}

ablation_data_file = os.path.join(tmp_dir, f"{original_prefix}ablation.pkl")
with open(ablation_data_file, "wb") as f:
pickle.dump(all_data, f)
print(f"Ablation data saved to {ablation_data_file}")
ablation_data_file = os.path.join(tmp_dir, f"{original_prefix}ablation-cost-model.pkl")
with open(ablation_data_file, "wb") as f:
pickle.dump(all_data, f)
print(f"Ablation data saved to {ablation_data_file}")

plot_ablation_results(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)


if __name__ == "__main__":
Expand Down

0 comments on commit 3fea89e

Please sign in to comment.