From 9076be55bd5c781d667174395ad8a7efd95744ca Mon Sep 17 00:00:00 2001 From: Brant-Skywalker Date: Fri, 6 Dec 2024 15:01:08 -0600 Subject: [PATCH] updated cost model --- experiments/ablation.py | 15 ++++++---- microbm/cm.csv | 66 +++++++++++++++++++++-------------------- microbm/microbm.py | 29 ++++++++++-------- 3 files changed, 60 insertions(+), 50 deletions(-) diff --git a/experiments/ablation.py b/experiments/ablation.py index b5016dc..cbf66a4 100644 --- a/experiments/ablation.py +++ b/experiments/ablation.py @@ -57,6 +57,10 @@ "-mllvm", "--fpopt-enable-pt", "-mllvm", + "--fpopt-cost-dom-thres=0", + "-mllvm", + "--fpopt-acc-dom-thres=0", + "-mllvm", "--fpopt-comp-cost-budget=0", "-mllvm", "--herbie-num-threads=8", @@ -619,16 +623,17 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo ax1.set_xlabel("Runtimes (seconds)") ax1.set_ylabel("Relative Errors (%)") ax1.set_title("Pareto Fronts for Different widen-range Values") - ax1.set_yscale("log") + ax1.set_yscale("symlog", linthresh=1e-15) + ax1.set_ylim(bottom=0) ax1.legend() ax1.grid(True) if ax2 is not None: - ax2.set_xlabel("Predicted Cost") - ax2.set_ylabel("Predicted Error (%)") + ax2.set_xlabel("Cost Budget") + ax2.set_ylabel("Predicted Error") ax2.set_title("Predicted Pareto Fronts") - ax2.set_yscale("log") - ax2.legend() + ax2.set_yscale("symlog", linthresh=1e-15) + ax2.set_ylim(bottom=-1e-13) ax2.grid(True) plot_filename = os.path.join(plots_dir, f"{prefix}ablation_widen_range_pareto_front.{output_format}") diff --git a/microbm/cm.csv b/microbm/cm.csv index 2122011..8760ada 100644 --- a/microbm/cm.csv +++ b/microbm/cm.csv @@ -1,23 +1,24 @@ -fneg,float,22 -fadd,float,22 -fsub,float,22 -fmul,float,22 -fdiv,float,22 -fcmp,float,16 -fpext_float_to_double,float,22 -sin,float,100 -cos,float,103 -tan,float,246 -exp,float,142 -log,float,83 -sqrt,float,33 -expm1,float,106 -log1p,float,110 -cbrt,float,584 -pow,float,159 -fabs,float,20 -hypot,float,608 -fma,float,75 +fneg,float,19 +fadd,float,19 +fsub,float,19 +fmul,float,19 +fdiv,float,19 +fcmp,float,15 +fpext_float_to_double,float,19 +fmuladd,float,19 +sin,float,91 +cos,float,98 +tan,float,243 +exp,float,125 +log,float,76 +sqrt,float,30 +expm1,float,76 +log1p,float,79 +cbrt,float,618 +pow,float,177 +fabs,float,19 +hypot,float,641 +fma,float,74 fneg,double,19 fadd,double,19 fsub,double,19 @@ -25,16 +26,17 @@ fmul,double,19 fdiv,double,28 fcmp,double,15 fptrunc_double_to_float,double,19 -sin,double,829 -cos,double,859 -tan,double,998 -exp,double,192 -log,double,104 +fmuladd,double,19 +sin,double,830 +cos,double,828 +tan,double,970 +exp,double,166 +log,double,97 sqrt,double,52 -expm1,double,75 -log1p,double,118 -cbrt,double,244 -pow,double,225 -fabs,double,20 -hypot,double,382 -fma,double,74 \ No newline at end of file +expm1,double,76 +log1p,double,115 +cbrt,double,237 +pow,double,279 +fabs,double,19 +hypot,double,383 +fma,double,75 diff --git a/microbm/microbm.py b/microbm/microbm.py index 15c7ca3..08ab212 100644 --- a/microbm/microbm.py +++ b/microbm/microbm.py @@ -10,8 +10,11 @@ random.seed(42) +# FAST_MATH_FLAG = "fast" +FAST_MATH_FLAG = "reassoc nsz arcp contract afn" + instructions = ["fneg", "fadd", "fsub", "fmul", "fdiv", "fcmp", "fptrunc", "fpext"] -functions = ["sin", "cos", "tan", "exp", "log", "sqrt", "expm1", "log1p", "cbrt", "pow", "fabs", "hypot", "fma"] +functions = ["fmuladd", "sin", "cos", "tan", "exp", "log", "sqrt", "expm1", "log1p", "cbrt", "pow", "fabs", "hypot", "fma"] precisions = ["float", "double"] # precisions = ["bf16", "half", "float", "double", "fp80", "fp128"] @@ -250,7 +253,7 @@ def generate_llvm_code(instruction, src_precision, dst_precision, iterations): for idx, hex_a in enumerate(hex_fps): code += f" %result{idx} = fptrunc {src_type} {hex_a} to {dst_type}\n" - code += f" %acc_val{idx+1} = fadd fast {dst_type} %acc_val{idx}, %result{idx}\n" + code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {dst_type} %acc_val{idx}, %result{idx}\n" code += f""" store {dst_type} %acc_val{len(hex_fps)}, {dst_type}* %acc @@ -291,7 +294,7 @@ def generate_llvm_code(instruction, src_precision, dst_precision, iterations): for idx, hex_a in enumerate(hex_fps): code += f" %result{idx} = fpext {src_type} {hex_a} to {dst_type}\n" - code += f" %acc_val{idx+1} = fadd fast {dst_type} %acc_val{idx}, %result{idx}\n" + code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {dst_type} %acc_val{idx}, %result{idx}\n" code += f""" store {dst_type} %acc_val{len(hex_fps)}, {dst_type}* %acc @@ -342,8 +345,8 @@ def generate_llvm_code_other(instruction, precision, iterations): """ for idx, (hex_a, hex_b) in enumerate(hex_pairs): - code += f" %result{idx} = {op} fast {llvm_type} {hex_a}, {hex_b}\n" - code += f" %acc_val{idx+1} = fadd fast {llvm_type} %acc_val{idx}, %result{idx}\n" + code += f" %result{idx} = {op} {FAST_MATH_FLAG} {llvm_type} {hex_a}, {hex_b}\n" + code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" code += f""" store {llvm_type} %acc_val{len(hex_pairs)}, {llvm_type}* %acc @@ -383,8 +386,8 @@ def generate_llvm_code_other(instruction, precision, iterations): """ for idx, hex_a in enumerate(hex_fps): - code += f" %result{idx} = fneg fast {llvm_type} {hex_a}\n" - code += f" %acc_val{idx+1} = fadd fast {llvm_type} %acc_val{idx}, %result{idx}\n" + code += f" %result{idx} = fneg {FAST_MATH_FLAG} {llvm_type} {hex_a}\n" + code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" code += f""" store {llvm_type} %acc_val{len(hex_fps)}, {llvm_type}* %acc @@ -425,7 +428,7 @@ def generate_llvm_code_other(instruction, precision, iterations): b = generate_random_fp(precision) hex_a = float_to_llvm_hex(a, precision) hex_b = float_to_llvm_hex(b, precision) - code += f" %cmp{idx} = fcmp fast olt {llvm_type} {hex_a}, {hex_b}\n" + code += f" %cmp{idx} = fcmp {FAST_MATH_FLAG} olt {llvm_type} {hex_a}, {hex_b}\n" code += f" %cmp_int{idx} = zext i1 %cmp{idx} to i32\n" code += f" %acc_val0 = load i32, i32* %acc\n" @@ -465,19 +468,19 @@ def generate_llvm_function_call(function_name, precision, iterations): function_intrinsic = f"llvm.pow.{intrinsic_suffix}" code += f"declare {llvm_type} @{function_intrinsic}({llvm_type}, {llvm_type})\n" function_call_template = ( - f"call fast {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}})" + f"call {FAST_MATH_FLAG} {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}})" ) elif function_name == "fmuladd": function_intrinsic = f"llvm.fmuladd.{intrinsic_suffix}" code += f"declare {llvm_type} @{function_intrinsic}({llvm_type}, {llvm_type}, {llvm_type})\n" - function_call_template = f"call fast {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}}, {llvm_type} {{arg3}})" + function_call_template = f"call {FAST_MATH_FLAG} {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}}, {llvm_type} {{arg3}})" elif function_name in functions_with_intrinsics: function_intrinsic = f"llvm.{function_name}.{intrinsic_suffix}" code += f"declare {llvm_type} @{function_intrinsic}({llvm_type})\n" - function_call_template = f"call fast {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}})" + function_call_template = f"call {FAST_MATH_FLAG} {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}})" else: code += f"declare {llvm_type} @{function_name}({llvm_type})\n" - function_call_template = f"call fast {llvm_type} @{function_name}({llvm_type} {{arg1}})" + function_call_template = f"call {FAST_MATH_FLAG} {llvm_type} @{function_name}({llvm_type} {{arg1}})" code += f""" define i32 @main() {{ entry: @@ -522,7 +525,7 @@ def generate_llvm_function_call(function_name, precision, iterations): function_call = function_call_template.format(arg1=hex_a) code += f" %result{idx} = {function_call}\n" - code += f" %acc_val{idx+1} = fadd fast {llvm_type} %acc_val{idx}, %result{idx}\n" + code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n" code += f""" store {llvm_type} %acc_val{unrolled}, {llvm_type}* %acc