Skip to content

Commit

Permalink
updated cost model
Browse files (browse the repository at this point in the history)
  • Loading branch information
sbrantq committed Dec 6, 2024
1 parent d869d81 commit 9076be5
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 50 deletions.
15 changes: 10 additions & 5 deletions experiments/ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@
"-mllvm",
"--fpopt-enable-pt",
"-mllvm",
"--fpopt-cost-dom-thres=0",
"-mllvm",
"--fpopt-acc-dom-thres=0",
"-mllvm",
"--fpopt-comp-cost-budget=0",
"-mllvm",
"--herbie-num-threads=8",
Expand Down Expand Up @@ -619,16 +623,17 @@ def plot_ablation_results(tmp_dir, plots_dir, original_prefix, prefix, output_fo
ax1.set_xlabel("Runtimes (seconds)")
ax1.set_ylabel("Relative Errors (%)")
ax1.set_title("Pareto Fronts for Different widen-range Values")
ax1.set_yscale("log")
ax1.set_yscale("symlog", linthresh=1e-15)
ax1.set_ylim(bottom=0)
ax1.legend()
ax1.grid(True)

if ax2 is not None:
ax2.set_xlabel("Predicted Cost")
ax2.set_ylabel("Predicted Error (%)")
ax2.set_xlabel("Cost Budget")
ax2.set_ylabel("Predicted Error")
ax2.set_title("Predicted Pareto Fronts")
ax2.set_yscale("log")
ax2.legend()
ax2.set_yscale("symlog", linthresh=1e-15)
ax2.set_ylim(bottom=-1e-13)
ax2.grid(True)

plot_filename = os.path.join(plots_dir, f"{prefix}ablation_widen_range_pareto_front.{output_format}")
Expand Down
66 changes: 34 additions & 32 deletions microbm/cm.csv
Original file line number Diff line number Diff line change
@@ -1,40 +1,42 @@
fneg,float,22
fadd,float,22
fsub,float,22
fmul,float,22
fdiv,float,22
fcmp,float,16
fpext_float_to_double,float,22
sin,float,100
cos,float,103
tan,float,246
exp,float,142
log,float,83
sqrt,float,33
expm1,float,106
log1p,float,110
cbrt,float,584
pow,float,159
fabs,float,20
hypot,float,608
fma,float,75
fneg,float,19
fadd,float,19
fsub,float,19
fmul,float,19
fdiv,float,19
fcmp,float,15
fpext_float_to_double,float,19
fmuladd,float,19
sin,float,91
cos,float,98
tan,float,243
exp,float,125
log,float,76
sqrt,float,30
expm1,float,76
log1p,float,79
cbrt,float,618
pow,float,177
fabs,float,19
hypot,float,641
fma,float,74
fneg,double,19
fadd,double,19
fsub,double,19
fmul,double,19
fdiv,double,28
fcmp,double,15
fptrunc_double_to_float,double,19
sin,double,829
cos,double,859
tan,double,998
exp,double,192
log,double,104
fmuladd,double,19
sin,double,830
cos,double,828
tan,double,970
exp,double,166
log,double,97
sqrt,double,52
expm1,double,75
log1p,double,118
cbrt,double,244
pow,double,225
fabs,double,20
hypot,double,382
fma,double,74
expm1,double,76
log1p,double,115
cbrt,double,237
pow,double,279
fabs,double,19
hypot,double,383
fma,double,75
29 changes: 16 additions & 13 deletions microbm/microbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@

random.seed(42)

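# Alternative: FAST_MATH_FLAG = "fast" (LLVM shorthand that enables every fast-math flag, including nnan and ninf).
# The explicit flag list below leaves nnan and ninf out.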
FAST_MATH_FLAG = "reassoc nsz arcp contract afn"

instructions = ["fneg", "fadd", "fsub", "fmul", "fdiv", "fcmp", "fptrunc", "fpext"]
functions = ["sin", "cos", "tan", "exp", "log", "sqrt", "expm1", "log1p", "cbrt", "pow", "fabs", "hypot", "fma"]
functions = ["fmuladd", "sin", "cos", "tan", "exp", "log", "sqrt", "expm1", "log1p", "cbrt", "pow", "fabs", "hypot", "fma"]

precisions = ["float", "double"]
# precisions = ["bf16", "half", "float", "double", "fp80", "fp128"]
Expand Down Expand Up @@ -250,7 +253,7 @@ def generate_llvm_code(instruction, src_precision, dst_precision, iterations):

for idx, hex_a in enumerate(hex_fps):
code += f" %result{idx} = fptrunc {src_type} {hex_a} to {dst_type}\n"
code += f" %acc_val{idx+1} = fadd fast {dst_type} %acc_val{idx}, %result{idx}\n"
code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {dst_type} %acc_val{idx}, %result{idx}\n"

code += f"""
store {dst_type} %acc_val{len(hex_fps)}, {dst_type}* %acc
Expand Down Expand Up @@ -291,7 +294,7 @@ def generate_llvm_code(instruction, src_precision, dst_precision, iterations):

for idx, hex_a in enumerate(hex_fps):
code += f" %result{idx} = fpext {src_type} {hex_a} to {dst_type}\n"
code += f" %acc_val{idx+1} = fadd fast {dst_type} %acc_val{idx}, %result{idx}\n"
code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {dst_type} %acc_val{idx}, %result{idx}\n"

code += f"""
store {dst_type} %acc_val{len(hex_fps)}, {dst_type}* %acc
Expand Down Expand Up @@ -342,8 +345,8 @@ def generate_llvm_code_other(instruction, precision, iterations):
"""

for idx, (hex_a, hex_b) in enumerate(hex_pairs):
code += f" %result{idx} = {op} fast {llvm_type} {hex_a}, {hex_b}\n"
code += f" %acc_val{idx+1} = fadd fast {llvm_type} %acc_val{idx}, %result{idx}\n"
code += f" %result{idx} = {op} {FAST_MATH_FLAG} {llvm_type} {hex_a}, {hex_b}\n"
code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n"

code += f"""
store {llvm_type} %acc_val{len(hex_pairs)}, {llvm_type}* %acc
Expand Down Expand Up @@ -383,8 +386,8 @@ def generate_llvm_code_other(instruction, precision, iterations):
"""

for idx, hex_a in enumerate(hex_fps):
code += f" %result{idx} = fneg fast {llvm_type} {hex_a}\n"
code += f" %acc_val{idx+1} = fadd fast {llvm_type} %acc_val{idx}, %result{idx}\n"
code += f" %result{idx} = fneg {FAST_MATH_FLAG} {llvm_type} {hex_a}\n"
code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n"

code += f"""
store {llvm_type} %acc_val{len(hex_fps)}, {llvm_type}* %acc
Expand Down Expand Up @@ -425,7 +428,7 @@ def generate_llvm_code_other(instruction, precision, iterations):
b = generate_random_fp(precision)
hex_a = float_to_llvm_hex(a, precision)
hex_b = float_to_llvm_hex(b, precision)
code += f" %cmp{idx} = fcmp fast olt {llvm_type} {hex_a}, {hex_b}\n"
code += f" %cmp{idx} = fcmp {FAST_MATH_FLAG} olt {llvm_type} {hex_a}, {hex_b}\n"
code += f" %cmp_int{idx} = zext i1 %cmp{idx} to i32\n"

code += f" %acc_val0 = load i32, i32* %acc\n"
Expand Down Expand Up @@ -465,19 +468,19 @@ def generate_llvm_function_call(function_name, precision, iterations):
function_intrinsic = f"llvm.pow.{intrinsic_suffix}"
code += f"declare {llvm_type} @{function_intrinsic}({llvm_type}, {llvm_type})\n"
function_call_template = (
f"call fast {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}})"
f"call {FAST_MATH_FLAG} {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}})"
)
elif function_name == "fmuladd":
function_intrinsic = f"llvm.fmuladd.{intrinsic_suffix}"
code += f"declare {llvm_type} @{function_intrinsic}({llvm_type}, {llvm_type}, {llvm_type})\n"
function_call_template = f"call fast {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}}, {llvm_type} {{arg3}})"
function_call_template = f"call {FAST_MATH_FLAG} {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}}, {llvm_type} {{arg2}}, {llvm_type} {{arg3}})"
elif function_name in functions_with_intrinsics:
function_intrinsic = f"llvm.{function_name}.{intrinsic_suffix}"
code += f"declare {llvm_type} @{function_intrinsic}({llvm_type})\n"
function_call_template = f"call fast {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}})"
function_call_template = f"call {FAST_MATH_FLAG} {llvm_type} @{function_intrinsic}({llvm_type} {{arg1}})"
else:
code += f"declare {llvm_type} @{function_name}({llvm_type})\n"
function_call_template = f"call fast {llvm_type} @{function_name}({llvm_type} {{arg1}})"
function_call_template = f"call {FAST_MATH_FLAG} {llvm_type} @{function_name}({llvm_type} {{arg1}})"
code += f"""
define i32 @main() {{
entry:
Expand Down Expand Up @@ -522,7 +525,7 @@ def generate_llvm_function_call(function_name, precision, iterations):
function_call = function_call_template.format(arg1=hex_a)

code += f" %result{idx} = {function_call}\n"
code += f" %acc_val{idx+1} = fadd fast {llvm_type} %acc_val{idx}, %result{idx}\n"
code += f" %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n"

code += f"""
store {llvm_type} %acc_val{unrolled}, {llvm_type}* %acc
Expand Down

0 comments on commit 9076be5

Please sign in to comment.