Skip to content

Commit

Permalink
fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Oct 31, 2024
1 parent e1e1a0a commit 23ea793
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 31 deletions.
2 changes: 1 addition & 1 deletion experiments/fpopt-baseline-generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def create_baseline_driver_function(functions, num_samples_per_func):
f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds."
)
dist_name = f"{func_name}_{param_name}_dist"
driver_code.append(f" std::uniform_real_distribution<double> {dist_name}({min_val}, {max_val});")
driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});")
driver_code.append("")

driver_code.append(" double sum = 0.;")
Expand Down
2 changes: 1 addition & 1 deletion experiments/fpopt-golden-driver-generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def create_driver_function(functions, num_samples_per_func):
f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds."
)
dist_name = f"{func_name}_{param_name}_dist"
driver_code.append(f" std::uniform_real_distribution<double> {dist_name}({min_val}, {max_val});")
driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});")
driver_code.append("")

driver_code.append(" double sum = 0.;")
Expand Down
2 changes: 1 addition & 1 deletion experiments/fpopt-logged-driver-generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_driver_function(functions, num_samples_per_func):
f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds."
)
dist_name = f"{func_name}_{param_name}_dist"
driver_code.append(f" std::uniform_real_distribution<double> {dist_name}({min_val}, {max_val});")
driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});")
driver_code.append("")

driver_code.append(" double sum = 0.;")
Expand Down
2 changes: 1 addition & 1 deletion experiments/fpopt-original-driver-generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def create_driver_function(functions, num_samples_per_func):
f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds."
)
dist_name = f"{func_name}_{param_name}_dist"
driver_code.append(f" std::uniform_real_distribution<double> {dist_name}({min_val}, {max_val});")
driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});")
driver_code.append("")

driver_code.append(" double sum = 0.;")
Expand Down
6 changes: 5 additions & 1 deletion experiments/run-all.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def process_function_task(task):
try:
subprocess.check_call(["python3", "run.py", "--prefix", prefix])
except subprocess.CalledProcessError as e:
print(f"Error running run.py for function {func_name} in base {base_name}")
print(f"Error running run.py for function {func_name} in base {base_name}. Retrying with --disable-preopt.")
try:
subprocess.check_call(["python3", "run.py", "--prefix", prefix, "--disable-preopt"])
except subprocess.CalledProcessError as e:
print(f"Error running run.py with --disable-preopt for function {func_name} in base {base_name}")


def main():
Expand Down
57 changes: 31 additions & 26 deletions experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,6 @@
"-fuse-ld=lld",
]

FPOPTFLAGS_BASE = [
"-mllvm",
"--enzyme-enable-fpopt",
"-mllvm",
"--enzyme-print-herbie",
"-mllvm",
"--enzyme-print-fpopt",
"-mllvm",
"--fpopt-log-path=example.txt",
"-mllvm",
"--fpopt-target-func-regex=example",
"-mllvm",
"--fpopt-enable-solver",
"-mllvm",
"--fpopt-enable-pt",
"-mllvm",
"--fpopt-comp-cost-budget=0",
"-mllvm",
"--fpopt-num-samples=1000",
"-mllvm",
"--fpopt-cost-model-path=../microbm/cm.csv",
# "-mllvm",
# "--herbie-disable-regime",
# "-mllvm",
# "--herbie-disable-taylor"
]

SRC = "example.c"
LOGGER = "fp-logger.cpp"
Expand Down Expand Up @@ -881,8 +855,39 @@ def main():
parser.add_argument("--plot-only", action="store_true", help="Plot results from existing data")
parser.add_argument("--output-format", type=str, default="png", help="Output format for plots (e.g., png, pdf)")
parser.add_argument("--analytics", action="store_true", help="Run analytics on saved data")
parser.add_argument("--disable-preopt", action="store_true", help="Disable Enzyme preoptimization")
args = parser.parse_args()

global FPOPTFLAGS_BASE # Ensure global scope
FPOPTFLAGS_BASE = [
"-mllvm",
"--enzyme-enable-fpopt",
"-mllvm",
"--enzyme-print-herbie",
"-mllvm",
"--enzyme-print-fpopt",
"-mllvm",
"--fpopt-log-path=example.txt",
"-mllvm",
"--fpopt-target-func-regex=example",
"-mllvm",
"--fpopt-enable-solver",
"-mllvm",
"--fpopt-enable-pt",
"-mllvm",
"--fpopt-comp-cost-budget=0",
"-mllvm",
"--fpopt-num-samples=1000",
"-mllvm",
"--fpopt-cost-model-path=../microbm/cm.csv",
# "-mllvm",
# "--herbie-disable-regime",
# "-mllvm",
# "--herbie-disable-taylor"
]
if args.disable_preopt:
FPOPTFLAGS_BASE.extend(["-mllvm", "--enzyme-preopt=0"])

prefix = args.prefix
if not prefix.endswith("-"):
prefix += "-"
Expand Down

0 comments on commit 23ea793

Please sign in to comment.