Skip to content

Commit

Permalink
separate out ablations
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Feb 11, 2025
1 parent 172fb84 commit 716b7b3
Show file tree
Hide file tree
Showing 11 changed files with 2,357 additions and 0 deletions.
Binary file added ablations/.DS_Store
Binary file not shown.
935 changes: 935 additions & 0 deletions ablations/ablation.py

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions ablations/example.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <math.h>
#include <stdint.h>
#define TRUE 1
#define FALSE 0

// ## PRE v: 20, 20000
// ## PRE T: -30, 50
// ## PRE u: -100, 100
// Evaluates -t1 * v / (t1 + u)^2 with t1 = 331.4 + 0.6 * T.
// Kept out-of-line so the call survives as a distinct function.
__attribute__((noinline))
double example(double u, double v, double T) {
  const double t1 = 331.4 + (0.6 * T);
  // The shifted term appears squared in the denominator; compute it once.
  const double shifted = t1 + u;
  const double numerator = -t1 * v;
  return numerator / (shifted * shifted);
}
165 changes: 165 additions & 0 deletions ablations/fp-logger.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "fp-logger.hpp"

// Running statistics for one logged value: result range, per-operand ranges,
// number of updates, and a log-domain accumulator used for an average of the
// result magnitudes.
class ValueInfo {
public:
  // Smallest / largest result seen so far (sentinels until the first update).
  double minRes = std::numeric_limits<double>::max();
  double maxRes = std::numeric_limits<double>::lowest();
  // Element i holds the smallest / largest value seen for operand i.
  std::vector<double> minOperands;
  std::vector<double> maxOperands;
  // Number of update() calls, including those whose result was NaN.
  unsigned executions = 0;
  // Sum of log1p(|res|) over the non-NaN results, and how many were summed.
  double logSum = 0.0;
  unsigned logCount = 0;

  // Folds one observation into the statistics.
  //   res         - the result value
  //   operands    - array of at least `numOperands` values (not read when
  //                 numOperands is 0)
  //   numOperands - number of entries to read from `operands`
  void update(double res, const double *operands, unsigned numOperands) {
    minRes = std::min(minRes, res);
    maxRes = std::max(maxRes, res);
    // Grow the per-operand ranges whenever this call supplies more operands
    // than any earlier one. Sizing them only on the first call would make the
    // loop below write past the end of the vectors for a larger later call.
    if (minOperands.size() < numOperands) {
      minOperands.resize(numOperands, std::numeric_limits<double>::max());
      maxOperands.resize(numOperands, std::numeric_limits<double>::lowest());
    }
    for (unsigned i = 0; i < numOperands; ++i) {
      minOperands[i] = std::min(minOperands[i], operands[i]);
      maxOperands[i] = std::max(maxOperands[i], operands[i]);
    }
    ++executions;

    // NaN results still count as executions but are kept out of the average.
    if (!std::isnan(res)) {
      logSum += std::log1p(std::fabs(res));
      ++logCount;
    }
  }

  // Returns expm1(mean(log1p(|res|))) over the non-NaN results, i.e. the
  // geometric mean of (1 + |res|) minus one; 0 when nothing was accumulated.
  double getGeometricAverage() const {
    if (logCount == 0) {
      return 0.;
    }
    return std::expm1(logSum / logCount);
  }
};

// Running minimum and maximum of the error values reported for one id.
class ErrorInfo {
public:
  // Sentinels: any finite sample replaces them on the first update.
  double minErr = std::numeric_limits<double>::max();
  double maxErr = std::numeric_limits<double>::lowest();

  // Widens [minErr, maxErr] to include `err`. A NaN compares false both
  // ways and therefore leaves the range untouched.
  void update(double err) {
    if (err < minErr) {
      minErr = err;
    }
    if (maxErr < err) {
      maxErr = err;
    }
  }
};

// Log-domain accumulator over the gradient magnitudes reported for one id.
class GradInfo {
public:
  // Sum of log1p(|grad|) over the accepted samples.
  double logSum = 0.0;
  // Number of accepted (non-NaN) samples.
  unsigned count = 0;

  // Adds one gradient sample; NaN samples are dropped entirely.
  void update(double grad) {
    if (std::isnan(grad)) {
      return;
    }
    ++count;
    logSum += std::log1p(std::fabs(grad));
  }

  // Returns expm1(logSum / count), i.e. the geometric mean of (1 + |grad|)
  // minus one, or 0 when no sample has been accepted.
  double getGeometricAverage() const {
    return count == 0 ? 0. : std::expm1(logSum / count);
  }
};

class Logger {
private:
std::unordered_map<std::string, ValueInfo> valueInfo;
std::unordered_map<std::string, ErrorInfo> errorInfo;
std::unordered_map<std::string, GradInfo> gradInfo;

public:
void updateValue(const std::string &id, double res, unsigned numOperands,
const double *operands) {
auto &info = valueInfo.emplace(id, ValueInfo()).first->second;
info.update(res, operands, numOperands);
}

void updateError(const std::string &id, double err) {
auto &info = errorInfo.emplace(id, ErrorInfo()).first->second;
info.update(err);
}

void updateGrad(const std::string &id, double grad) {
auto &info = gradInfo.emplace(id, GradInfo()).first->second;
info.update(grad);
}

void print() const {
std::cout << std::scientific
<< std::setprecision(std::numeric_limits<double>::max_digits10);

for (const auto &pair : valueInfo) {
const auto &id = pair.first;
const auto &info = pair.second;
std::cout << "Value:" << id << "\n";
std::cout << "\tMinRes = " << info.minRes << "\n";
std::cout << "\tMaxRes = " << info.maxRes << "\n";
std::cout << "\tExecutions = " << info.executions << "\n";
std::cout << "\tGeometric Average = " << info.getGeometricAverage()
<< "\n";
for (unsigned i = 0; i < info.minOperands.size(); ++i) {
std::cout << "\tOperand[" << i << "] = [" << info.minOperands[i] << ", "
<< info.maxOperands[i] << "]\n";
}
}

for (const auto &pair : errorInfo) {
const auto &id = pair.first;
const auto &info = pair.second;
std::cout << "Error:" << id << "\n";
std::cout << "\tMinErr = " << info.minErr << "\n";
std::cout << "\tMaxErr = " << info.maxErr << "\n";
}

for (const auto &pair : gradInfo) {
const auto &id = pair.first;
const auto &info = pair.second;
std::cout << "Grad:" << id << "\n";
std::cout << "\tGrad = " << info.getGeometricAverage() << "\n";
}
}
};

// Global logger instance used by the functions below. Null until
// initializeLogger() assigns it, and reset to null by destroyLogger().
Logger *logger = nullptr;

// Installs a fresh, empty global Logger. A logger left over from an earlier
// call is deleted first so that repeated initialization does not leak it
// (deleting a null pointer is a no-op).
void initializeLogger() {
  delete logger;
  logger = new Logger();
}

void destroyLogger() {
delete logger;
logger = nullptr;
}

// Prints everything the global logger has recorded. As with the enzymeLog*
// entry points, the logger must have been initialized first.
void printLogger() {
  assert(logger && "Logger is not initialized");
  logger->print();
}

// Records the error sample `err` under the key `id` in the global logger.
// The logger must already be initialized.
void enzymeLogError(const char *id, double err) {
  assert(logger && "Logger is not initialized");
  const std::string key(id);
  logger->updateError(key, err);
}

// Records the gradient sample `grad` under the key `id` in the global logger.
// The logger must already be initialized.
void enzymeLogGrad(const char *id, double grad) {
  assert(logger && "Logger is not initialized");
  const std::string key(id);
  logger->updateGrad(key, grad);
}

// Records the result `res` together with `numOperands` operand values read
// from `operands`, under the key `id` in the global logger. The logger must
// already be initialized.
void enzymeLogValue(const char *id, double res, unsigned numOperands,
                    double *operands) {
  assert(logger && "Logger is not initialized");
  const std::string key(id);
  logger->updateValue(key, res, numOperands, operands);
}
8 changes: 8 additions & 0 deletions ablations/fp-logger.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Creates the global logger that the enzymeLog* functions record into.
void initializeLogger();
// Deletes the global logger and resets it to null.
void destroyLogger();
// Prints all recorded value, error and gradient statistics to standard output.
void printLogger();

// Records one error sample `err` under the key `id`.
void enzymeLogError(const char *id, double err);
// Records one gradient sample `grad` under the key `id`.
void enzymeLogGrad(const char *id, double grad);
// Records one result `res` and the `numOperands` values pointed to by
// `operands` under the key `id`.
void enzymeLogValue(const char *id, double res, unsigned numOperands,
double *operands);
194 changes: 194 additions & 0 deletions ablations/fpopt-baseline-generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
#!/usr/bin/env python3

import os
import sys
import re
import numpy as np

# Number of samples per function when none is given on the command line.
DEFAULT_NUM_SAMPLES = 10000
# Function-name regex used when none is given on the command line.
DEFAULT_REGEX = "ex\\d+"

# Fix the seed of NumPy's global RNG.
# NOTE(review): no NumPy random call is visible in this script — confirm whether the seed is still needed.
np.random.seed(42)


def parse_bound(bound):
    """Convert a bound string to a float.

    Accepts either a plain number (``"2.5"``) or a ratio written as
    ``"numerator/denominator"`` (``"1/4"``), which is evaluated.
    """
    if "/" not in bound:
        return float(bound)
    top, bottom = (float(piece) for piece in bound.split("/"))
    return top / bottom


def parse_c_file(filepath, func_regex):
    """Extract annotated functions from a C source file.

    A function is recognised when a run of ``// ## PRE name: min, max``
    comment lines is followed by a definition whose name matches
    ``func_regex``.

    Args:
        filepath: Path of the C file to read.
        func_regex: Regular expression for the function names to pick up.
            It may contain its own capture groups.

    Returns:
        A list of ``(func_name, bounds, params, return_type)`` tuples, where
        ``bounds`` maps a parameter name to ``{"min": float, "max": float}``
        and ``params`` is the list of parameter declarations as written.

    Exits the program with a message when no function matches.
    """
    with open(filepath, "r") as file:
        content = file.read()

    # Named groups keep the four fields addressable even when func_regex
    # brings capture groups of its own; positional unpacking of findall()
    # tuples would break in that case.
    pattern = re.compile(
        rf"(?s)(?P<comments>// ## PRE(?:.*?\n)+?)\s*(?P<ret>[\w\s\*]+?)\s+(?P<name>{func_regex})\s*\((?P<params>[^)]*)\)"
    )
    # Bounds may use exponent notation (e.g. 1e-3), hence e/E in the class.
    bound_pattern = re.compile(r"// ## PRE (\w+):\s*([-+.\deE/]+),\s*([-+.\deE/]+)")

    functions = []
    for match in pattern.finditer(content):
        bounds = {
            name: {
                "min": parse_bound(min_val),
                "max": parse_bound(max_val),
            }
            for name, min_val, max_val in bound_pattern.findall(match.group("comments"))
        }
        params = [param.strip() for param in match.group("params").split(",") if param.strip()]
        functions.append((match.group("name"), bounds, params, match.group("ret").strip()))

    if not functions:
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is absent under `python -S`.
        sys.exit(f"No functions found with the regex: {func_regex}")

    return functions


def create_baseline_functions(functions):
    """Emit one stub per parsed function.

    Each stub is named ``baseline_<func_name>``, keeps the original parameter
    list and return type, is marked ``noinline`` and returns the constant
    ``42.0``. Every stub is followed by an empty line.

    Args:
        functions: Iterable of ``(func_name, bounds, params, return_type)``.

    Returns:
        The generated source as a single newline-joined string.
    """
    chunks = []
    for name, _bounds, param_decls, ret_type in functions:
        signature = f"{ret_type} baseline_{name}({', '.join(param_decls)})"
        chunks.extend(
            [
                f"__attribute__((noinline))\n{signature} {{",
                " return 42.0;",
                "}",
                "",
            ]
        )
    return "\n".join(chunks)


def _param_name(param):
    """Return the identifier of a parameter declaration such as ``double x``.

    Exits the program when the declaration has fewer than two
    whitespace-separated tokens.
    """
    tokens = param.strip().split()
    if len(tokens) < 2:
        sys.exit(f"Cannot parse parameter: {param}")
    return tokens[-1]


def create_baseline_driver_function(functions, num_samples_per_func):
    """Generate the C++ ``main`` that drives the baseline stubs.

    The generated program accepts an optional ``--output-path <file>``
    argument, draws every argument from a uniform distribution over the
    parameter's bounds using a ``std::mt19937`` seeded with 42, calls each
    ``baseline_<func_name>`` ``num_samples_per_func`` times, accumulates the
    results into ``sum``, optionally writes each result to the output file,
    and prints the sum and the elapsed time.

    Args:
        functions: Iterable of ``(func_name, bounds, params, return_type)``.
        num_samples_per_func: Loop count emitted for every function.

    Returns:
        The generated source as a single newline-joined string.

    Exits the program when a parameter cannot be parsed or has no bounds.
    """
    driver_code = [
        "#include <iostream>",
        "#include <random>",
        "#include <cstring>",
        "#include <chrono>",
        "#include <fstream>",
        "#include <limits>",
        "#include <iomanip>",
        "#include <string>",
        "",
        "int main(int argc, char* argv[]) {",
        ' std::string output_path = "";',
        "",
        " for (int i = 1; i < argc; ++i) {",
        ' if (std::strcmp(argv[i], "--output-path") == 0) {',
        " if (i + 1 < argc) {",
        " output_path = argv[i + 1];",
        " i++;",
        " } else {",
        ' std::cerr << "Error: --output-path requires a path argument." << std::endl;',
        " return 1;",
        " }",
        " }",
        " }",
        "",
        " bool save_outputs = !output_path.empty();",
        "",
        " std::mt19937 gen(42);",
        "",
        " std::ofstream ofs;",
        " if (save_outputs) {",
        " ofs.open(output_path);",
        " if (!ofs) {",
        ' std::cerr << "Failed to open output file: " << output_path << std::endl;',
        " return 1;",
        " }",
        " }",
        "",
    ]

    # One distribution per (function, parameter); its element type is the
    # function's return type.
    for func_name, bounds, params, return_type in functions:
        for param in params:
            param_name = _param_name(param)
            try:
                min_val = bounds[param_name]["min"]
                max_val = bounds[param_name]["max"]
            except KeyError:
                sys.exit(
                    f"WARNING: Bounds not found for {param_name} in function {func_name}, manually specify the bounds."
                )
            dist_name = f"{func_name}_{param_name}_dist"
            driver_code.append(f" std::uniform_real_distribution<{return_type}> {dist_name}({min_val}, {max_val});")
        driver_code.append("")

    driver_code.extend(
        [
            " double sum = 0.;",
            "",
            " auto start_time = std::chrono::high_resolution_clock::now();",
            "",
        ]
    )

    # One sampling loop per function.
    for func_name, _bounds, params, _return_type in functions:
        call_args = ", ".join(f"{func_name}_{_param_name(param)}_dist(gen)" for param in params)
        driver_code.extend(
            [
                f" for (int i = 0; i < {num_samples_per_func}; ++i) {{",
                f" double res = baseline_{func_name}({call_args});",
                " sum += res;",
                " if (save_outputs) {",
                ' ofs << std::setprecision(std::numeric_limits<double>::digits10 + 1) << res << "\\n";',
                " }",
                " }",
                "",
            ]
        )

    driver_code.extend(
        [
            ' std::cout << "Sum: " << sum << std::endl;',
            " auto end_time = std::chrono::high_resolution_clock::now();",
            " std::chrono::duration<double> elapsed = end_time - start_time;",
            ' std::cout << "Total runtime: " << elapsed.count() << " seconds\\n";',
            "",
            " if (save_outputs) {",
            " ofs.close();",
            " }",
            "",
            " return 0;",
            "}",
        ]
    )
    return "\n".join(driver_code)


def main():
    """Command-line entry point.

    Usage: ``fpopt-baseline-generator.py <source_path> <dest_path>
    [func_regex] [num_samples_per_func]``. Parses the annotated functions in
    ``source_path`` and writes the baseline stubs followed by the generated
    driver to ``dest_path``.
    """
    if len(sys.argv) < 3:
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use.
        sys.exit(
            "Usage: fpopt-baseline-generator.py <source_path> <dest_path> [func_regex] [num_samples_per_func (default: 10000)]"
        )

    source_path = sys.argv[1]
    dest_path = sys.argv[2]
    func_regex = sys.argv[3] if len(sys.argv) > 3 else DEFAULT_REGEX
    num_samples_per_func = int(sys.argv[4]) if len(sys.argv) > 4 else DEFAULT_NUM_SAMPLES

    functions = parse_c_file(source_path, func_regex)
    baseline_functions_code = create_baseline_functions(functions)
    driver_code = create_baseline_driver_function(functions, num_samples_per_func)

    # Stubs first, then the driver, separated by two newlines.
    with open(dest_path, "w") as new_file:
        new_file.write(baseline_functions_code)
        new_file.write("\n\n")
        new_file.write(driver_code)

    print(f"Baseline code written to the new file: {dest_path}")


if __name__ == "__main__":
    main()
Loading

0 comments on commit 716b7b3

Please sign in to comment.