diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 414374dd779..f6df56ff257 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3833,6 +3833,12 @@ bool GradientUtils::legalRecompute(const Value *val, reverse); // TODO ADD && !TR.intType(getOriginal(dli), // /*mustfind*/false).isPossibleFloat(); } + if (auto ci = dyn_cast(uiv)) { + auto called = ci->getCalledFunction(); + if (ci->hasFnAttr("enzyme_shouldrecompute") || + (called && called->hasFnAttribute("enzyme_shouldrecompute"))) + return true; + } if (phi->getNumIncomingValues() == 0) { return false; }