Skip to content

Commit

Permalink
geomean fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Feb 13, 2025
1 parent babcc9f commit 9db684c
Showing 1 changed file with 84 additions and 41 deletions.
125 changes: 84 additions & 41 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,60 @@ static cl::opt<double> FPOptAccuracyDominanceThreshold(
cl::desc("The threshold for accuracy dominance in DP solver"));
}

// https://arxiv.org/pdf/1806.06403
double geomean(const SmallVectorImpl<double> &dataset, double epsilon = 1e-5) {
std::vector<double> dataset_nozeros;
for (double x : dataset) {
if (x != 0.0)
dataset_nozeros.push_back(x);
}

if (dataset_nozeros.empty()) {
return 0.0;
}

double sum_log = 0.0;
for (double x : dataset_nozeros) {
sum_log += std::log(x);
}
double geomeanNozeros = std::exp(sum_log / dataset_nozeros.size());

double min_val =
*std::min_element(dataset_nozeros.begin(), dataset_nozeros.end());
double deltamin = 0.0;
double deltamax = std::max(geomeanNozeros - min_val, 0.0);
double delta = (deltamin + deltamax) / 2.0;
double epsilon_threshold = epsilon * geomeanNozeros;

auto compute_auxExp = [&](double d) -> double {
double sum = 0.0;
for (double x : dataset_nozeros) {
sum += std::log(x + d);
}
return std::exp(sum / dataset_nozeros.size()) - d;
};

double auxExp = compute_auxExp(delta);

while ((auxExp - geomeanNozeros) > epsilon_threshold) {
if (auxExp < geomeanNozeros)
deltamin = delta;
else
deltamax = delta;
delta = (deltamin + deltamax) / 2.0;
auxExp = compute_auxExp(delta);
}

double sum_log_all = 0.0;
for (double x : dataset) {
sum_log_all += std::log(x + delta);
}
double gmeanE = std::exp(sum_log_all / dataset.size()) - delta;

assert(!std::isnan(gmeanE) && !std::isinf(gmeanE));
return gmeanE;
}

class FPNode {
public:
enum class NodeType { Node, LLValue, Const };
Expand Down Expand Up @@ -1078,6 +1132,7 @@ struct PTCandidate {
InstructionCost CompCost;
std::string desc;
std::unordered_map<FPNode *, double> perOutputAccCost;
std::unordered_map<FPNode *, SmallVector<double, 4>> errors;

// TODO:
explicit PTCandidate(SmallVector<PrecisionChange> changes,
Expand Down Expand Up @@ -3110,6 +3165,7 @@ class ApplicableFPCC {
InstructionCost initialCompCost;
unsigned executions; // Requires manual initialization
std::unordered_map<FPNode *, double> perOutputInitialAccCost;
std::unordered_map<FPNode *, SmallVector<double, 4>> errors;

SmallVector<PTCandidate, 8> candidates;

Expand Down Expand Up @@ -3294,9 +3350,8 @@ void setUnifiedAccuracyCost(

SmallVector<double, 4> goldVals;
goldVals.resize(FPOptNumSamples);
double initAC = 0.;
SmallVector<double, 4> errors;

unsigned numValidSamples = 0;
for (const auto &pair : enumerate(sampledPoints)) {
ArrayRef<FPNode *> outputs = {valueToNodeMap[AO.oldOutput].get()};
SmallVector<double, 1> results;
Expand All @@ -3309,25 +3364,24 @@ void setUnifiedAccuracyCost(
double realVal = results[0];
// llvm::errs() << "DEBUG AO real value: " << realVal << "\n";

if (!std::isnan(goldVal) && !std::isnan(realVal)) {
initAC += std::log1p(std::fabs(goldVal - realVal));
numValidSamples++;
double error = std::fabs(goldVal - realVal);
if (!std::isnan(error) && !std::isinf(error)) {
errors.push_back(error);
}
}

AO.initialAccCost = std::expm1(initAC / numValidSamples) * std::fabs(AO.grad);
assert(!errors.empty() && "No valid samples for AO -- try increasing the "
"number of samples");
AO.initialAccCost = geomean(errors) * std::fabs(AO.grad);
// llvm::errs() << "DEBUG calculated AO initial accuracy cost: "
// << AO.initialAccCost << "\n";
assert(numValidSamples && "No valid samples for AO -- try increasing the "
"number of samples");
assert(!std::isnan(AO.initialAccCost));

for (auto &candidate : AO.candidates) {
const auto &expr = candidate.expr;
auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap);
double ac = 0.;

numValidSamples = 0;
SmallVector<double, 4> errors;
for (const auto &pair : enumerate(sampledPoints)) {
// Compute the "gold" value & real value for each sampled point
// Compute an average of (difference * gradient)
Expand All @@ -3352,15 +3406,14 @@ void setUnifiedAccuracyCost(

// llvm::errs() << "Real value: " << realVal << "\n";
double goldVal = goldVals[pair.index()];
if (!std::isnan(goldVal) && !std::isnan(realVal)) {
ac += std::log1p(std::fabs(goldVal - realVal));
numValidSamples++;
double error = std::fabs(goldVal - realVal);
if (!std::isnan(error) && !std::isinf(error)) {
errors.push_back(error);
}
}
assert(numValidSamples && "No valid samples for AO -- try increasing the "
assert(!errors.empty() && "No valid samples for AO -- try increasing the "
"number of samples");
candidate.accuracyCost =
std::expm1(ac / numValidSamples) * std::fabs(AO.grad);
candidate.accuracyCost = geomean(errors) * std::fabs(AO.grad);
assert(!std::isnan(candidate.accuracyCost));
}
}
Expand All @@ -3387,11 +3440,6 @@ void setUnifiedAccuracyCost(
outputs.push_back(valueToNodeMap[output].get());
}

std::unordered_map<FPNode *, unsigned> numValidSamplesPerOutput;
for (auto *output : outputs) {
numValidSamplesPerOutput[output] = 0;
}

for (const auto &pair : enumerate(sampledPoints)) {
SmallVector<double, 8> results;

Expand All @@ -3408,24 +3456,22 @@ void setUnifiedAccuracyCost(
for (const auto &[output, result] : zip(outputs, results)) {
// llvm::errs() << "DEBUG ACC real value: " << result << "\n";
double goldVal = goldVals[output][pair.index()];
if (!std::isnan(goldVal) && !std::isnan(result)) {
double diff = std::fabs(goldVal - result);
ACC.perOutputInitialAccCost[output] += std::log1p(diff);
numValidSamplesPerOutput[output]++;
double error = std::fabs(goldVal - result);
if (!std::isnan(error) && !std::isinf(error)) {
ACC.errors[output].push_back(error);
}
}
}

// Normalize accuracy costs and compute aggregated initialAccCost
ACC.initialAccCost = 0.0;
for (auto *output : outputs) {
unsigned numValidSamples = numValidSamplesPerOutput[output];
assert(numValidSamples && "No valid samples for at least one output node "
"-- try increasing the number of samples");
assert(!ACC.errors[output].empty() &&
"No valid samples for at least one output node "
"-- try increasing the number of samples");
// Local error --> global error
ACC.perOutputInitialAccCost[output] =
std::expm1(ACC.perOutputInitialAccCost[output] / numValidSamples) *
std::fabs(output->grad);
geomean(ACC.errors[output]) * std::fabs(output->grad);
// llvm::errs() << "DEBUG calculated ACC per output initial accuracy cost: "
// << ACC.perOutputInitialAccCost[output] << "\n";
ACC.initialAccCost += ACC.perOutputInitialAccCost[output];
Expand All @@ -3437,7 +3483,6 @@ void setUnifiedAccuracyCost(
std::unordered_map<FPNode *, unsigned> numValidSamplesPerOutput;
for (auto *output : outputs) {
candidate.perOutputAccCost[output] = 0.;
numValidSamplesPerOutput[output] = 0;
}

for (const auto &pair : enumerate(sampledPoints)) {
Expand All @@ -3446,25 +3491,23 @@ void setUnifiedAccuracyCost(

for (const auto &[output, result] : zip(outputs, results)) {
double goldVal = goldVals[output][pair.index()];
if (!std::isnan(goldVal) && !std::isnan(result)) {
double diff = std::fabs(goldVal - result);
// Sum up local errors
candidate.perOutputAccCost[output] += std::log1p(diff);
numValidSamplesPerOutput[output]++;
double error = std::fabs(goldVal - result);
if (!std::isnan(error) && !std::isinf(error)) {
// Record local errors
candidate.errors[output].push_back(error);
}
}
}

// Normalize accuracy costs and compute aggregated accuracyCost
candidate.accuracyCost = 0.0;
for (auto *output : outputs) {
unsigned numValidSamples = numValidSamplesPerOutput[output];
assert(numValidSamples && "No valid samples for output -- try increasing "
"the number of samples");
assert(!candidate.errors[output].empty() &&
"No valid samples for output -- try increasing "
"the number of samples");
// Local error --> global error
candidate.perOutputAccCost[output] =
std::expm1(candidate.perOutputAccCost[output] / numValidSamples) *
std::fabs(output->grad);
geomean(candidate.errors[output]) * std::fabs(output->grad);
// llvm::errs()
// << "DEBUG calculated ACC per output candidate accuracy cost: "
// << candidate.perOutputAccCost[output] << "\n";
Expand Down

0 comments on commit 9db684c

Please sign in to comment.