Skip to content

Commit 33bee0d

Browse files
committed
Handle differentiation via invoke
1 parent 81e87fc commit 33bee0d

File tree

3 files changed

+99
-19
lines changed

3 files changed

+99
-19
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ class Enzyme : public ModulePass {
101101
}
102102

103103
/// Return whether successful
104-
template <typename T>
105-
bool HandleAutoDiff(T *CI, TargetLibraryInfo &TLI, bool PostOpt,
104+
bool HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, bool PostOpt,
106105
bool fwdMode) {
107106

108107
Value *fn = CI->getArgOperand(0);
@@ -575,9 +574,65 @@ class Enzyme : public ModulePass {
575574

576575
bool Changed = false;
577576

577+
for (BasicBlock &BB : F)
578+
if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) {
579+
580+
Function *Fn = II->getCalledFunction();
581+
582+
#if LLVM_VERSION_MAJOR >= 11
583+
if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand()))
584+
#else
585+
if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledValue()))
586+
#endif
587+
{
588+
if (castinst->isCast())
589+
if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
590+
Fn = fn;
591+
}
592+
if (!Fn)
593+
continue;
594+
595+
if (!(Fn->getName() == "__enzyme_float" ||
596+
Fn->getName() == "__enzyme_double" ||
597+
Fn->getName() == "__enzyme_integer" ||
598+
Fn->getName() == "__enzyme_pointer" ||
599+
Fn->getName().contains("__enzyme_call_inactive") ||
600+
Fn->getName().contains("__enzyme_autodiff") ||
601+
Fn->getName().contains("__enzyme_fwddiff")))
602+
continue;
603+
604+
SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
605+
SmallVector<OperandBundleDef, 1> OpBundles;
606+
II->getOperandBundlesAsDefs(OpBundles);
607+
// Insert a normal call instruction...
608+
#if LLVM_VERSION_MAJOR >= 8
609+
CallInst *NewCall =
610+
CallInst::Create(II->getFunctionType(), II->getCalledOperand(),
611+
CallArgs, OpBundles, "", II);
612+
#else
613+
CallInst *NewCall =
614+
CallInst::Create(II->getFunctionType(), II->getCalledValue(),
615+
CallArgs, OpBundles, "", II);
616+
#endif
617+
NewCall->takeName(II);
618+
NewCall->setCallingConv(II->getCallingConv());
619+
NewCall->setAttributes(II->getAttributes());
620+
NewCall->setDebugLoc(II->getDebugLoc());
621+
II->replaceAllUsesWith(NewCall);
622+
623+
// Insert an unconditional branch to the normal destination.
624+
BranchInst::Create(II->getNormalDest(), II);
625+
626+
// Remove any PHI node entries from the exception destination.
627+
II->getUnwindDest()->removePredecessor(&BB);
628+
629+
// Remove the invoke instruction now.
630+
BB.getInstList().erase(II);
631+
Changed = true;
632+
}
633+
578634
std::set<CallInst *> toLowerAuto;
579635
std::set<CallInst *> toLowerFwd;
580-
std::set<InvokeInst *> toLowerI;
581636
std::set<CallInst *> InactiveCalls;
582637
retry:;
583638
for (BasicBlock &BB : F) {
@@ -752,15 +807,9 @@ class Enzyme : public ModulePass {
752807
}
753808
}
754809

755-
bool autoDiff = Fn && (Fn->getName() == "__enzyme_autodiff" ||
756-
Fn->getName() == "enzyme_autodiff_" ||
757-
Fn->getName().startswith("__enzyme_autodiff") ||
758-
Fn->getName().contains("__enzyme_autodiff"));
810+
bool autoDiff = Fn && Fn->getName().contains("__enzyme_autodiff");
759811

760-
bool fwdDiff = Fn && (Fn->getName() == "__enzyme_fwddiff" ||
761-
Fn->getName() == "enzyme_fwddiff_" ||
762-
Fn->getName().startswith("__enzyme_fwddiff") ||
763-
Fn->getName().contains("__enzyme_fwddiff"));
812+
bool fwdDiff = Fn && Fn->getName().contains("__enzyme_fwddiff");
764813

765814
if (autoDiff || fwdDiff) {
766815
if (autoDiff) {
@@ -845,13 +894,6 @@ class Enzyme : public ModulePass {
845894
break;
846895
}
847896

848-
for (auto CI : toLowerI) {
849-
successful &= HandleAutoDiff(CI, TLI, PostOpt, /*fwdMode*/ false);
850-
Changed = true;
851-
if (!successful)
852-
break;
853-
}
854-
855897
if (Changed) {
856898
// TODO consider enabling when attributor does not delete
857899
// dead internal functions, which invalidates Enzyme's cache

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2736,6 +2736,13 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
27362736
shouldRecompute(LI, incoming_available, &BuilderM);
27372737
}
27382738
}
2739+
if (!inst->mayReadOrWriteMemory()) {
2740+
reduceRegister |= tryLegalRecomputeCheck &&
2741+
legalRecompute(inst, incoming_available, &BuilderM) &&
2742+
shouldRecompute(inst, incoming_available, &BuilderM);
2743+
}
2744+
if (this->isOriginalBlock(*BuilderM.GetInsertBlock()))
2745+
reduceRegister = false;
27392746
}
27402747

27412748
if (!reduceRegister) {
@@ -2928,7 +2935,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
29282935
}
29292936
inst = cast<Instruction>(val);
29302937
assert(prelcssaInst->getType() == inst->getType());
2931-
29322938
assert(!this->isOriginalBlock(*BuilderM.GetInsertBlock()));
29332939

29342940
// Update index and caching per lcssa
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -S | FileCheck %s
2+
3+
define double @sq(double %x) {
4+
entry:
5+
%0 = fmul fast double %x, %x
6+
ret double %0
7+
}
8+
9+
declare i32 @__gxx_personality_v0(...)
10+
11+
; Function Attrs: norecurse ssp uwtable
12+
define double @caller(double %x) personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
13+
%res = invoke double (...) @_Z17__enzyme_autodiffz(double (double)* nonnull @sq, double %x)
14+
to label %eblock unwind label %cblock
15+
16+
eblock:
17+
ret double %res
18+
19+
cblock:
20+
%lp = landingpad { i8*, i32 }
21+
cleanup
22+
ret double 0.000000e+00
23+
}
24+
25+
declare double @_Z17__enzyme_autodiffz(...)
26+
27+
; CHECK: define double @caller(double %x)
28+
; CHECK-NEXT: eblock:
29+
; CHECK-NEXT: %0 = call { double } @diffesq(double %x, double 1.000000e+00)
30+
; CHECK-NEXT: %1 = extractvalue { double } %0, 0
31+
; CHECK-NEXT: ret double %1
32+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)