Skip to content

Commit f0f64db

Browse files
authored
Add forward mode fadd and free debug (rust-lang#674)
1 parent bb70046 commit f0f64db

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,56 @@ class AdjointGenerator
778778
}
779779

780780
void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
781+
if (Mode == DerivativeMode::ForwardMode) {
782+
IRBuilder<> BuilderZ(&I);
783+
getForwardBuilder(BuilderZ);
784+
switch (I.getOperation()) {
785+
case AtomicRMWInst::FAdd:
786+
case AtomicRMWInst::FSub: {
787+
auto rule = [&](Value *ptr, Value *dif) -> Value * {
788+
if (!gutils->isConstantInstruction(&I)) {
789+
assert(ptr);
790+
AtomicRMWInst *rmw = nullptr;
791+
#if LLVM_VERSION_MAJOR >= 13
792+
rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif,
793+
I.getAlign(), I.getOrdering(),
794+
I.getSyncScopeID());
795+
#elif LLVM_VERSION_MAJOR >= 11
796+
rmw = BuilderZ.CreateAtomicRMW(I.getOperation(), ptr, dif,
797+
I.getOrdering(), I.getSyncScopeID());
798+
rmw->setAlignment(I.getAlign());
799+
#else
800+
rmw = BuilderZ.CreateAtomicRMW(
801+
I.getOperation(), ptr, dif, I.getOrdering(),
802+
I.getSyncScopeID());
803+
#endif
804+
rmw->setVolatile(I.isVolatile());
805+
if (gutils->isConstantValue(&I))
806+
return Constant::getNullValue(dif->getType());
807+
else
808+
return rmw;
809+
} else {
810+
assert(gutils->isConstantValue(&I));
811+
return Constant::getNullValue(dif->getType());
812+
}
813+
};
814+
815+
Value *diff = applyChainRule(
816+
I.getType(), BuilderZ, rule,
817+
gutils->isConstantValue(I.getPointerOperand())
818+
? nullptr
819+
: gutils->invertPointerM(I.getPointerOperand(), BuilderZ),
820+
gutils->isConstantValue(I.getValOperand())
821+
? Constant::getNullValue(I.getType())
822+
: gutils->invertPointerM(I.getValOperand(), BuilderZ));
823+
if (!gutils->isConstantValue(&I))
824+
setDiffe(&I, diff, BuilderZ);
825+
return;
826+
}
827+
default:
828+
break;
829+
}
830+
}
781831
if (!gutils->isConstantInstruction(&I) || !gutils->isConstantValue(&I)) {
782832
TR.dump();
783833
llvm::errs() << "oldFunc: " << *gutils->newFunc << "\n";
@@ -11083,7 +11133,8 @@ class AdjointGenerator
1108311133
auto rule = [&args](Value *tofree) { args.push_back(tofree); };
1108411134
applyChainRule(Builder2, rule, tofree);
1108511135

11086-
Builder2.CreateCall(free->getFunctionType(), free, args);
11136+
auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
11137+
frees->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
1108711138

1108811139
return;
1108911140
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3821,7 +3821,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
38213821
}
38223822
}
38233823

3824-
if (auto arg = dyn_cast<ConstantExpr>(oval)) {
3824+
if (isa<ConstantExpr>(oval)) {
38253825
auto rule = [&oval]() { return oval; };
38263826
return applyChainRule(oval->getType(), BuilderM, rule);
38273827
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
2+
3+
; Function Attrs: norecurse nounwind readonly uwtable
4+
define dso_local double @sum(double* nocapture %n, double %x) #0 {
5+
entry:
6+
%res = atomicrmw fadd double* %n, double %x monotonic
7+
ret double %res
8+
}
9+
10+
; Function Attrs: nounwind uwtable
11+
define dso_local void @dsum(double* %x, double* %xp, double %n) local_unnamed_addr #1 {
12+
entry:
13+
%0 = tail call double (double (double*, double)*, ...) @__enzyme_fwddiff(double (double*, double)* nonnull @sum, double* %x, double* %xp, double %n, double 1.000000e+00)
14+
ret void
15+
}
16+
17+
; Function Attrs: nounwind
18+
declare double @__enzyme_fwddiff(double (double*, double)*, ...) #2
19+
20+
attributes #0 = { norecurse nounwind readonly uwtable }
21+
attributes #1 = { nounwind uwtable }
22+
attributes #2 = { nounwind }
23+
24+
; CHECK: define internal double @fwddiffesum(double* nocapture %n, double* nocapture %"n'", double %x, double %"x'")
25+
; CHECK-NEXT: entry:
26+
; CHECK-NEXT: %res = atomicrmw fadd double* %n, double %x monotonic
27+
; CHECK-NEXT: %0 = atomicrmw fadd double* %"n'", double %"x'" monotonic
28+
; CHECK-NEXT: ret double %0
29+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)