Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Jan 24, 2025
1 parent 72a2a94 commit 22da4f8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
30 changes: 20 additions & 10 deletions experiments/run-all.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def extract_functions_from_c_file(content, func_regex="ex\\d+"):
return functions


def process_function_task(task):
func, base_name = task
def process_function_task(func, base_name, plot_only):
func_name = func["func_name"]
return_type = func["return_type"]
params = func["params"]
Expand All @@ -77,24 +76,35 @@ def process_function_task(task):
example_c_content = includes + "\n" + func_code
with open(example_c_filepath, "w") as f:
f.write(example_c_content)

base_command = ["python3", "run.py", "--prefix", prefix]

if plot_only:
base_command.append("--plot-only")

try:
subprocess.check_call(["python3", "run.py", "--prefix", prefix])
except subprocess.CalledProcessError as e:
subprocess.check_call(base_command)
except subprocess.CalledProcessError:
print(f"Error running run.py for function {func_name} in base {base_name}. Retrying with --disable-preopt.")
retry_command = ["python3", "run.py", "--prefix", prefix, "--disable-preopt"]
if plot_only:
retry_command.append("--plot-only")
try:
subprocess.check_call(["python3", "run.py", "--prefix", prefix, "--disable-preopt"])
except subprocess.CalledProcessError as e:
subprocess.check_call(retry_command)
except subprocess.CalledProcessError:
print(f"Error running run.py with --disable-preopt for function {func_name} in base {base_name}")


def main():
parser = argparse.ArgumentParser(description="Process functions from .fpcore files.")
parser.add_argument("-j", type=int, help="Number of parallel tasks", default=1)
parser.add_argument("--regen", action="store_true", help="Force regeneration of .c files")
parser.add_argument("--plot-only", action="store_true", help="Only plot the results")
args = parser.parse_args()

num_parallel_tasks = args.j
force_regen = args.regen
plot_only = args.plot_only

source_dir = "../benchmarks"
exported_dir = "exported"
Expand Down Expand Up @@ -123,7 +133,7 @@ def main():
try:
print("Running command: ", " ".join(["racket", racket_script, fpcore_file, c_filepath]))
subprocess.check_call(["racket", racket_script, fpcore_file, c_filepath])
except subprocess.CalledProcessError as e:
except subprocess.CalledProcessError:
print(f"Error running export.rkt on {filename}")
continue

Expand All @@ -137,16 +147,16 @@ def main():
for func in functions:
func_name = func["func_name"]
print(f"Found function: {func_name}")
task = (func, base_name)
task = (func, base_name, plot_only)
tasks.append(task)

if tasks:
if num_parallel_tasks == 1:
for task in tasks:
process_function_task(task)
process_function_task(*task)
else:
with multiprocessing.Pool(num_parallel_tasks) as pool:
pool.map(process_function_task, tasks)
pool.starmap(process_function_task, tasks)
else:
print("No functions to process.")

Expand Down
16 changes: 8 additions & 8 deletions experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from concurrent.futures import ProcessPoolExecutor, as_completed

HOME = "/home/sbrantq"
ENZYME_PATH = os.path.join(HOME, "sync/Enzyme/build/Enzyme/ClangEnzyme-15.so")
LLVM_PATH = os.path.join(HOME, "llvms/llvm15/build/bin")
ENZYME_PATH = os.path.join(HOME, "sync/Enzyme/build/Enzyme/ClangEnzyme-16.so")
LLVM_PATH = os.path.join(HOME, "llvms/llvm16/build/bin")
CXX = os.path.join(LLVM_PATH, "clang++")

CXXFLAGS = [
Expand Down Expand Up @@ -443,7 +443,7 @@ def plot_results(
color_runtime = "tab:blue"
ax1.set_xlabel("Computation Cost Budget")
ax1.set_ylabel("Runtimes (seconds)", color=color_runtime)
(line1,) = ax1.plot(
(line1,) = ax1.step(
budgets, runtimes, marker="o", linestyle="-", label="Optimized Runtimes", color=color_runtime
)
if original_runtime is not None:
Expand All @@ -453,7 +453,7 @@ def plot_results(
ax2 = ax1.twinx()
color_error = "tab:green"
ax2.set_ylabel("Relative Errors (%)", color=color_error)
(line3,) = ax2.plot(
(line3,) = ax2.step(
budgets, errors, marker="s", linestyle="-", label="Optimized Relative Errors", color=color_error
)
if original_error is not None:
Expand Down Expand Up @@ -522,7 +522,7 @@ def plot_results(

pareto_front = np.array(pareto_front)

(line_pareto,) = ax3.plot(
(line_pareto,) = ax3.step(
pareto_front[:, 0], pareto_front[:, 1], linestyle="-", color="purple", label="Pareto Front"
)
ax3.set_yscale("log")
Expand Down Expand Up @@ -562,7 +562,7 @@ def plot_results(
color_runtime = "tab:blue"
ax1.set_xlabel("Computation Cost Budget")
ax1.set_ylabel("Runtimes (seconds)", color=color_runtime)
(line1,) = ax1.plot(
(line1,) = ax1.step(
budgets, runtimes, marker="o", linestyle="-", label="Optimized Runtimes", color=color_runtime
)
if original_runtime is not None:
Expand All @@ -572,7 +572,7 @@ def plot_results(
ax2 = ax1.twinx()
color_error = "tab:green"
ax2.set_ylabel("Relative Errors (%)", color=color_error)
(line3,) = ax2.plot(
(line3,) = ax2.step(
budgets, errors, marker="s", linestyle="-", label="Optimized Relative Errors", color=color_error
)
if original_error is not None:
Expand Down Expand Up @@ -631,7 +631,7 @@ def plot_results(

pareto_front = np.array(pareto_front)

(line_pareto,) = ax3.plot(
(line_pareto,) = ax3.step(
pareto_front[:, 0], pareto_front[:, 1], linestyle="-", color="purple", label="Pareto Front"
)
ax3.set_yscale("log")
Expand Down

0 comments on commit 22da4f8

Please sign in to comment.