Skip to content

Commit

Permalink
Adapt to upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 6, 2025
1 parent df197be commit 73e48de
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 85 deletions.
14 changes: 6 additions & 8 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
AttributeList AL;
AL = AL.addParamAttribute(DT->getContext(), 0,
Attribute::AttrKind::ReadOnly);
AL = AL.addParamAttribute(DT->getContext(), 0,
Attribute::AttrKind::NoCapture);
AL = addFunctionNoCapture(DT->getContext(), AL, 0);
AL =
AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NoAlias);
AL =
AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NonNull);
AL = AL.addParamAttribute(DT->getContext(), 1,
Attribute::AttrKind::WriteOnly);
AL = AL.addParamAttribute(DT->getContext(), 1,
Attribute::AttrKind::NoCapture);
AL = addFunctionNoCapture(DT->getContext(), AL, 1);
AL =
AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NoAlias);
AL =
Expand Down Expand Up @@ -208,11 +206,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy);
AttributeList AL;
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly);
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoCapture);
AL = addFunctionNoCapture(context, AL, 0);
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias);
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoCapture);
AL = addFunctionNoCapture(context, AL, 1);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull);
AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
Expand Down Expand Up @@ -241,11 +239,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy);
AttributeList AL;
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly);
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoCapture);
AL = addFunctionNoCapture(context, AL, 0);
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias);
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoCapture);
AL = addFunctionNoCapture(context, AL, 1);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias);
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull);
AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ std::pair<PHINode *, Instruction *> FindCanonicalIV(Loop *L, Type *Ty) {
continue;
if (!Inc)
continue;
if (Inc != Header->getFirstNonPHIOrDbg())
Inc->moveBefore(Header->getFirstNonPHIOrDbg());
if (Inc != getFirstNonPHIOrDbg(Header))
Inc->moveBefore(getFirstNonPHIOrDbg(Header));
return std::make_pair(PN, Inc);
}
llvm::errs() << *Header << "\n";
Expand Down
44 changes: 22 additions & 22 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
if (F.getName() == "fprintf") {
for (auto &arg : F.args()) {
if (arg.getType()->isPointerTy()) {
arg.addAttr(Attribute::NoCapture);
addFunctionNoCapture(&F, arg.getArgNo());
changed = true;
}
}
Expand All @@ -148,7 +148,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
for (auto &arg : F.args()) {
if (arg.getType()->isPointerTy()) {
arg.addAttr(Attribute::ReadNone);
arg.addAttr(Attribute::NoCapture);
addFunctionNoCapture(&F, arg.getArgNo());
}
}
}
Expand All @@ -168,7 +168,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.addFnAttr(Attribute::NoSync);
for (int i = 0; i < 2; i++)
if (F.getFunctionType()->getParamType(i)->isPointerTy()) {
F.addParamAttr(i, Attribute::NoCapture);
addFunctionNoCapture(&F, i);
F.addParamAttr(i, Attribute::WriteOnly);
}
}
Expand All @@ -192,7 +192,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.addFnAttr(Attribute::NoSync);
F.addParamAttr(0, Attribute::WriteOnly);
if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
F.addParamAttr(2, Attribute::NoCapture);
addFunctionNoCapture(&F, 2);
F.addParamAttr(2, Attribute::WriteOnly);
}
F.addParamAttr(6, Attribute::WriteOnly);
Expand All @@ -211,7 +211,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.addFnAttr(Attribute::NoSync);
F.addParamAttr(0, Attribute::ReadOnly);
if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
F.addParamAttr(2, Attribute::NoCapture);
addFunctionNoCapture(&F, 2);
F.addParamAttr(2, Attribute::ReadOnly);
}
F.addParamAttr(6, Attribute::WriteOnly);
Expand All @@ -231,12 +231,12 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.addFnAttr(Attribute::NoSync);

if (F.getFunctionType()->getParamType(0)->isPointerTy()) {
F.addParamAttr(0, Attribute::NoCapture);
addFunctionNoCapture(&F, 0);
F.addParamAttr(0, Attribute::ReadOnly);
}
if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
F.addParamAttr(1, Attribute::WriteOnly);
F.addParamAttr(1, Attribute::NoCapture);
addFunctionNoCapture(&F, 1);
}
}
if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") {
Expand All @@ -246,9 +246,9 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.addFnAttr(Attribute::WillReturn);
F.addFnAttr(Attribute::NoFree);
F.addFnAttr(Attribute::NoSync);
F.addParamAttr(0, Attribute::NoCapture);
addFunctionNoCapture(&F, 0);
F.addParamAttr(1, Attribute::WriteOnly);
F.addParamAttr(1, Attribute::NoCapture);
addFunctionNoCapture(&F, 1);
}
if (F.getName() == "MPI_Waitall" || F.getName() == "PMPI_Waitall") {
changed = true;
Expand All @@ -257,9 +257,9 @@ bool attributeKnownFunctions(llvm::Function &F) {
F.addFnAttr(Attribute::WillReturn);
F.addFnAttr(Attribute::NoFree);
F.addFnAttr(Attribute::NoSync);
F.addParamAttr(1, Attribute::NoCapture);
addFunctionNoCapture(&F, 1);
F.addParamAttr(2, Attribute::WriteOnly);
F.addParamAttr(2, Attribute::NoCapture);
addFunctionNoCapture(&F, 2);
}
// Map of MPI function name to the arg index of its type argument
std::map<std::string, int> MPI_TYPE_ARGS = {
Expand Down Expand Up @@ -2347,7 +2347,7 @@ class EnzymeBase {
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2361,7 +2361,7 @@ class EnzymeBase {
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2375,7 +2375,7 @@ class EnzymeBase {
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2389,7 +2389,7 @@ class EnzymeBase {
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadNone);
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand Down Expand Up @@ -2439,9 +2439,9 @@ class EnzymeBase {
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
#endif
CI->addParamAttr(1, Attribute::ReadOnly);
CI->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(CI, 1);
CI->addParamAttr(3, Attribute::ReadOnly);
CI->addParamAttr(3, Attribute::NoCapture);
addCallSiteNoCapture(CI, 3);
}
if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" ||
Fn->getName() == "frexpl") {
Expand Down Expand Up @@ -2502,7 +2502,7 @@ class EnzymeBase {
for (size_t i : {0, 1}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2527,7 +2527,7 @@ class EnzymeBase {
for (size_t i : {0, 2}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2553,7 +2553,7 @@ class EnzymeBase {
for (size_t i : {0, 1, 2, 3}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2579,7 +2579,7 @@ class EnzymeBase {
for (size_t i : {0}) {
if (i < num_args &&
CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand All @@ -2601,7 +2601,7 @@ class EnzymeBase {
for (size_t i = 0; i < num_args; ++i) {
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
CI->addParamAttr(i, Attribute::ReadOnly);
CI->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(CI, i);
}
}
}
Expand Down
12 changes: 10 additions & 2 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2713,9 +2713,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
auto i = nf->arg_begin(), j = NewF->arg_begin();
while (i != nf->arg_end()) {
VMap[i] = j;
if (nf->hasParamAttribute(attrIndex, Attribute::NoCapture)) {
NewF->addParamAttr(attrIndex, Attribute::NoCapture);
#if LLVM_VERSION_MAJOR > 20
if (nf->hasParamAttribute(attrIndex, Attribute::Captures)) {
NewF->addParamAttr(attrIndex,
nf->getParamAttribute(attrIndex, Attribute::Captures));
}
#else
if (nf->hasParamAttribute(attrIndex, Attribute::NoCaptures)) {
NewF->addParamAttr(
attrIndex, nf->getParamAttribute(attrIndex, Attribute::NoCapture));
}
#endif
if (nf->hasParamAttribute(attrIndex, Attribute::NoAlias)) {
NewF->addParamAttr(attrIndex, Attribute::NoAlias);
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
AttributeList list;
list = list.addFnAttribute(NewF->getContext(), Attribute::ReadOnly);
list = list.addParamAttribute(NewF->getContext(), 0, Attribute::ReadNone);
list = list.addParamAttribute(NewF->getContext(), 0, Attribute::NoCapture);
addFunctionNoCapture(NewF->getContext(), list, 0);
auto allocSize = NewF->getParent()->getOrInsertFunction(
allocName,
FunctionType::get(
Expand Down Expand Up @@ -1109,7 +1109,7 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
B.SetInsertPoint(Bound->getNextNode());
}
B.CreateStore(B.CreateLoad(AI2->getAllocatedType(), AI2), AI);
Bound->addParamAttr(i, Attribute::NoCapture);
addCallSiteNoCapture(Bound, i);
}
}
PreservedAnalyses PA;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/TraceGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void TraceGenerator::visitFunction(Function &F) {
return;

auto fn = tutils->newFunc;
auto entry = fn->getEntryBlock().getFirstNonPHIOrDbgOrLifetime();
auto entry = getFirstNonPHIOrDbgOrLifetime(&fn->getEntryBlock());

while (isa<AllocaInst>(entry) && entry->getNextNode()) {
entry = entry->getNextNode();
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/TraceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ DynamicTraceInterface::DynamicTraceInterface(Value *dynamicInterface,
assert(dynamicInterface);

auto &M = *F->getParent();
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbg());
IRBuilder<> Builder(getFirstNonPHIOrDbg(&F->getEntryBlock()));

getTraceFunction = MaterializeInterfaceFunction(
Builder, dynamicInterface, getTraceTy(), 0, M, "get_trace");
Expand Down
32 changes: 14 additions & 18 deletions enzyme/Enzyme/TraceUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,8 @@ TraceUtils::ValueToVoidPtrAndSize(IRBuilder<> &Builder, Value *val,
Builder.CreateIntToPtr(cast, getInt8PtrTy(cast->getContext()));
return {retval, ConstantInt::get(size_type, valsize / 8)};
} else {
auto insertPoint = Builder.GetInsertBlock()
->getParent()
->getEntryBlock()
.getFirstNonPHIOrDbgOrLifetime();
IRBuilder<> AllocaBuilder(insertPoint);
IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime(
&Builder.GetInsertBlock()->getParent()->getEntryBlock()));
auto alloca = AllocaBuilder.CreateAlloca(val->getType(), nullptr,
val->getName() + ".ptr");
Builder.CreateStore(val, alloca);
Expand Down Expand Up @@ -248,7 +245,8 @@ CallInst *TraceUtils::InsertChoice(IRBuilder<> &Builder, Value *address,
auto call = Builder.CreateCall(interface->insertChoiceTy(),
interface->insertChoice(Builder), args);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);

addCallSiteNoCapture(call, 1);
return call;
}

Expand All @@ -259,7 +257,7 @@ CallInst *TraceUtils::InsertCall(IRBuilder<> &Builder, Value *address,
auto call = Builder.CreateCall(interface->insertCallTy(),
interface->insertCall(Builder), args);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
#if LLVM_VERSION_MAJOR >= 14
call->addAttributeAtIndex(
AttributeList::FunctionIndex,
Expand All @@ -283,7 +281,7 @@ CallInst *TraceUtils::InsertArgument(IRBuilder<> &Builder, Value *name,
auto call = Builder.CreateCall(interface->insertArgumentTy(),
interface->insertArgument(Builder), args);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return call;
}

Expand Down Expand Up @@ -322,7 +320,7 @@ CallInst *TraceUtils::InsertChoiceGradient(IRBuilder<> &Builder,

auto call = Builder.CreateCall(interface_type, interface_function, args);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return call;
}

Expand All @@ -339,7 +337,7 @@ CallInst *TraceUtils::InsertArgumentGradient(IRBuilder<> &Builder,

auto call = Builder.CreateCall(interface_type, interface_function, args);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return call;
}

Expand All @@ -352,16 +350,14 @@ CallInst *TraceUtils::GetTrace(IRBuilder<> &Builder, Value *address,
auto call = Builder.CreateCall(interface->getTraceTy(),
interface->getTrace(Builder), args, Name);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return call;
}

Instruction *TraceUtils::GetChoice(IRBuilder<> &Builder, Value *address,
Type *choiceType, const Twine &Name) {
IRBuilder<> AllocaBuilder(Builder.GetInsertBlock()
->getParent()
->getEntryBlock()
.getFirstNonPHIOrDbgOrLifetime());
IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime(
&Builder.GetInsertBlock()->getParent()->getEntryBlock()));
AllocaInst *store_dest =
AllocaBuilder.CreateAlloca(choiceType, nullptr, Name + ".ptr");
auto preallocated_size = choiceType->getPrimitiveSizeInBits() / 8;
Expand All @@ -385,7 +381,7 @@ Instruction *TraceUtils::GetChoice(IRBuilder<> &Builder, Value *address,
Attribute::get(call->getContext(), "enzyme_inactive"));
#endif
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return Builder.CreateLoad(choiceType, store_dest, "from.trace." + Name);
}

Expand All @@ -396,7 +392,7 @@ Instruction *TraceUtils::HasChoice(IRBuilder<> &Builder, Value *address,
auto call = Builder.CreateCall(interface->hasChoiceTy(),
interface->hasChoice(Builder), args, Name);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return call;
}

Expand All @@ -407,7 +403,7 @@ Instruction *TraceUtils::HasCall(IRBuilder<> &Builder, Value *address,
auto call = Builder.CreateCall(interface->hasCallTy(),
interface->hasCall(Builder), args, Name);
call->addParamAttr(1, Attribute::ReadOnly);
call->addParamAttr(1, Attribute::NoCapture);
addCallSiteNoCapture(call, 1);
return call;
}

Expand Down
Loading

0 comments on commit 73e48de

Please sign in to comment.