Skip to content

Commit

Permalink
show prediction
Browse files: browse the repository at this point in the history
  • Loading branch information
sbrantq committed Dec 6, 2024
1 parent 3fea89e commit d869d81
Showing 1 changed file with 174 additions and 33 deletions.
207 changes: 174 additions & 33 deletions experiments/ablation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@
import math
import random
import numpy as np
import json
from statistics import mean
import pickle
from tqdm import trange
Expand Down Expand Up @@ -407,7 +408,6 @@ def process_cost(args):
output_binary = f"example-fpopt-{cost}.exe"

compile_example_fpopt_exe(tmp_dir, prefix, fpoptflags_cost, output=output_binary, verbose=False)

generate_values(tmp_dir, prefix, output_binary)

return cost, output_binary
Expand Down Expand Up @@ -481,6 +481,27 @@ def benchmark(tmp_dir, logs_dir, original_prefix, prefix, plots_dir, fpoptflags,
"original_runtime": original_runtime,
"original_error": rel_err_example,
}

table_json_path = os.path.join("cache", "table.json")
if os.path.exists(table_json_path):
with open(table_json_path, "r") as f:
table_data = json.load(f)
if "costToAccuracyMap" in table_data:
predicted_costs = []
predicted_errors = []
for k, v in table_data["costToAccuracyMap"].items():
try:
cost_val = float(k)
err_val = float(v)
predicted_costs.append(cost_val)
predicted_errors.append(err_val)
except ValueError:
pass
pred_sorted_indices = np.argsort(predicted_costs)
predicted_costs = np.array(predicted_costs)[pred_sorted_indices].tolist()
predicted_errors = np.array(predicted_errors)[pred_sorted_indices].tolist()
data["predicted_data"] = {"costs": predicted_costs, "errors": predicted_errors}

data_file = os.path.join(tmp_dir, f"{prefix}benchmark_data.pkl")
with open(data_file, "wb") as f:
pickle.dump(data, f)
Expand All @@ -504,22 +525,38 @@ def remove_cache_dir():
print("=== Removed existing cache directory ===")


def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_format="png"):
def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_format="png", show_prediction=False):
ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation-widen-range.pkl")
if not os.path.exists(ablation_data_file):
print(f"Ablation data file {ablation_data_file} does not exist. Cannot plot.")
sys.exit(1)
with open(ablation_data_file, "rb") as f:
all_data = pickle.load(f)

if not all_data:
print("No data to plot.")
sys.exit(1)

rcParams["font.size"] = 20
rcParams["axes.titlesize"] = 24
rcParams["axes.labelsize"] = 20
rcParams["xtick.labelsize"] = 18
rcParams["ytick.labelsize"] = 18
rcParams["legend.fontsize"] = 18

plt.figure(figsize=(10, 8))
if show_prediction and any("predicted_data" in d for d in all_data.values()):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
else:
fig, ax1 = plt.subplots(1, 1, figsize=(10, 8))
ax2 = None

# Extract original runtime/error from the first scenario
first_key = next(iter(all_data))
original_runtime = all_data[first_key]["original_runtime"]
original_error = all_data[first_key]["original_error"]

# Plot original program once
ax1.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program")

colors = ["blue", "green", "red", "purple", "orange", "brown", "pink", "gray", "olive", "cyan"]
color_iter = iter(colors)
Expand All @@ -528,8 +565,6 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
budgets = data["budgets"]
runtimes = data["runtimes"]
errors = data["errors"]
original_runtime = data["original_runtime"]
original_error = data["original_error"]

data_points = list(zip(runtimes, errors))
filtered_data = [(r, e) for r, e in data_points if r is not None and e is not None]
Expand All @@ -538,68 +573,113 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
continue
runtimes_filtered, errors_filtered = zip(*filtered_data)
color = next(color_iter)
plt.scatter(runtimes_filtered, errors_filtered, label=f"widen-range={X}", color=color)
ax1.scatter(runtimes_filtered, errors_filtered, label=f"widen-range={X}", color=color)
points = np.array(filtered_data)
sorted_indices = np.argsort(points[:, 0])
sorted_points = points[sorted_indices]

# Compute pareto front
pareto_front = [sorted_points[0]]
for point in sorted_points[1:]:
if point[1] < pareto_front[-1][1]:
pareto_front.append(point)

pareto_front = np.array(pareto_front)

plt.step(
ax1.step(
pareto_front[:, 0],
pareto_front[:, 1],
where="post",
linestyle="-",
color=color,
)

plt.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program")
if show_prediction and "predicted_data" in data and ax2 is not None:
p_costs = data["predicted_data"]["costs"]
p_errors = data["predicted_data"]["errors"]
if p_costs and p_errors:
ax2.scatter(p_costs, p_errors, label=f"Prediction (widen-range={X})", color=color)
pred_points = np.column_stack((p_costs, p_errors))
pred_sorted_indices = np.argsort(pred_points[:, 0])
pred_sorted_points = pred_points[pred_sorted_indices]

pred_pareto = [pred_sorted_points[0]]
for pt in pred_sorted_points[1:]:
if pt[1] < pred_pareto[-1][1]:
pred_pareto.append(pt)
pred_pareto = np.array(pred_pareto)

ax2.step(
pred_pareto[:, 0],
pred_pareto[:, 1],
where="post",
linestyle="-",
color=color,
)

plt.xlabel("Runtimes (seconds)")
plt.ylabel("Relative Errors (%)")
plt.title("Pareto Fronts for Different widen-range Values")
plt.yscale("log")
plt.legend()
plt.grid(True)
ax1.set_xlabel("Runtimes (seconds)")
ax1.set_ylabel("Relative Errors (%)")
ax1.set_title("Pareto Fronts for Different widen-range Values")
ax1.set_yscale("log")
ax1.legend()
ax1.grid(True)

if ax2 is not None:
ax2.set_xlabel("Predicted Cost")
ax2.set_ylabel("Predicted Error (%)")
ax2.set_title("Predicted Pareto Fronts")
ax2.set_yscale("log")
ax2.legend()
ax2.grid(True)

plot_filename = os.path.join(plots_dir, f"{prefix}ablation_widen_range_pareto_front.{output_format}")
plt.savefig(plot_filename, bbox_inches="tight", dpi=300)
plt.close()
print(f"Ablation plot saved to {plot_filename}")


def plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, prefix, output_format="png"):
def plot_ablation_results_cost_model(
tmp_dir, plots_dir, original_prefix, prefix, output_format="png", show_prediction=False
):
ablation_data_file = os.path.join(tmp_dir, f"{prefix}ablation-cost-model.pkl")
if not os.path.exists(ablation_data_file):
print(f"Ablation data file {ablation_data_file} does not exist. Cannot plot.")
sys.exit(1)
with open(ablation_data_file, "rb") as f:
all_data = pickle.load(f)

if not all_data:
print("No data to plot.")
sys.exit(1)

rcParams["font.size"] = 20
rcParams["axes.titlesize"] = 24
rcParams["axes.labelsize"] = 20
rcParams["xtick.labelsize"] = 18
rcParams["ytick.labelsize"] = 18
rcParams["legend.fontsize"] = 18

plt.figure(figsize=(10, 8))
if show_prediction and any("predicted_data" in d for d in all_data.values()):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
else:
fig, ax1 = plt.subplots(1, 1, figsize=(10, 8))
ax2 = None

# Extract original runtime/error from the first scenario
first_key = next(iter(all_data))
original_runtime = all_data[first_key]["original_runtime"]
original_error = all_data[first_key]["original_error"]

# Plot original program once
ax1.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program")

colors = ["blue", "green"]
labels = ["With Cost Model", "Without Cost Model"]
labels = ["Custom Cost Model", "TTI Cost Model"]

for idx, key in enumerate(["with_cost_model", "without_cost_model"]):
data = all_data[key]
budgets = data["budgets"]
runtimes = data["runtimes"]
errors = data["errors"]
original_runtime = data["original_runtime"]
original_error = data["original_error"]

data_points = list(zip(runtimes, errors))
filtered_data = [(r, e) for r, e in data_points if r is not None and e is not None]
Expand All @@ -608,7 +688,7 @@ def plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, prefix
continue
runtimes_filtered, errors_filtered = zip(*filtered_data)
color = colors[idx]
plt.scatter(runtimes_filtered, errors_filtered, label=labels[idx], color=color)
ax1.scatter(runtimes_filtered, errors_filtered, label=labels[idx], color=color)
points = np.array(filtered_data)
sorted_indices = np.argsort(points[:, 0])
sorted_points = points[sorted_indices]
Expand All @@ -620,22 +700,53 @@ def plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, prefix

pareto_front = np.array(pareto_front)

plt.step(
ax1.step(
pareto_front[:, 0],
pareto_front[:, 1],
where="post",
linestyle="-",
color=color,
)

plt.scatter(original_runtime, original_error, marker="x", color="black", s=100, label="Original Program")
if show_prediction and "predicted_data" in data and ax2 is not None:
p_costs = data["predicted_data"]["costs"]
p_errors = data["predicted_data"]["errors"]
if p_costs and p_errors:
ax2.scatter(p_costs, p_errors, label=f"{labels[idx]}", color=color)
pred_points = np.column_stack((p_costs, p_errors))
pred_sorted_indices = np.argsort(pred_points[:, 0])
pred_sorted_points = pred_points[pred_sorted_indices]

pred_pareto = [pred_sorted_points[0]]
for pt in pred_sorted_points[1:]:
if pt[1] < pred_pareto[-1][1]:
pred_pareto.append(pt)
pred_pareto = np.array(pred_pareto)

ax2.step(
pred_pareto[:, 0],
pred_pareto[:, 1],
where="post",
linestyle="-",
color=color,
)

plt.xlabel("Runtimes (seconds)")
plt.ylabel("Relative Errors (%)")
plt.title("Pareto Fronts for Cost Model Ablation")
plt.yscale("log")
plt.legend()
plt.grid(True)
ax1.set_xlabel("Runtimes (seconds)")
ax1.set_ylabel("Relative Errors (%)")
ax1.set_title("Pareto Fronts for Cost Model Ablation")
ax1.set_yscale("symlog", linthresh=1e-14)
ax1.set_ylim(bottom=-1e-14)
ax1.legend()
ax1.grid(True)

if ax2 is not None:
ax2.set_xlabel("Cost Budget")
ax2.set_ylabel("Predicted Error")
ax2.set_title("Predicted Pareto Front")
ax2.set_yscale("symlog", linthresh=1e-14)
ax2.set_ylim(bottom=-1e-14)
ax2.legend()
ax2.grid(True)

plot_filename = os.path.join(plots_dir, f"{prefix}ablation_cost_model_pareto_front.{output_format}")
plt.savefig(plot_filename, bbox_inches="tight", dpi=300)
Expand Down Expand Up @@ -671,6 +782,8 @@ def main():
default="widen-range",
help="Type of ablation study to perform (default: widen-range)",
)
parser.add_argument("--show-prediction", action="store_true", help="Show predicted results in a second subplot")

args = parser.parse_args()

original_prefix = args.prefix
Expand All @@ -692,9 +805,23 @@ def main():
sys.exit(0)
elif args.plot_only:
if args.ablation_type == "widen-range":
plot_ablation_results(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
plot_ablation_results(
tmp_dir,
plots_dir,
original_prefix,
original_prefix,
args.output_format,
show_prediction=args.show_prediction,
)
elif args.ablation_type == "cost-model":
plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
plot_ablation_results_cost_model(
tmp_dir,
plots_dir,
original_prefix,
original_prefix,
args.output_format,
show_prediction=args.show_prediction,
)
sys.exit(0)
else:
if not os.path.exists(example_txt_path):
Expand Down Expand Up @@ -734,7 +861,14 @@ def main():
pickle.dump(all_data, f)
print(f"Ablation data saved to {ablation_data_file}")

plot_ablation_results(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
plot_ablation_results(
tmp_dir,
plots_dir,
original_prefix,
original_prefix,
args.output_format,
show_prediction=args.show_prediction,
)

if args.ablation_type == "cost-model":
print("=== Running cost-model ablation study ===")
Expand Down Expand Up @@ -781,7 +915,14 @@ def main():
pickle.dump(all_data, f)
print(f"Ablation data saved to {ablation_data_file}")

plot_ablation_results_cost_model(tmp_dir, plots_dir, original_prefix, original_prefix, args.output_format)
plot_ablation_results_cost_model(
tmp_dir,
plots_dir,
original_prefix,
original_prefix,
args.output_format,
show_prediction=args.show_prediction,
)


if __name__ == "__main__":
Expand Down

0 comments on commit d869d81

Please sign in to comment.