@@ -101,8 +101,7 @@ class Enzyme : public ModulePass {
101
101
}
102
102
103
103
// / Return whether successful
104
- template <typename T>
105
- bool HandleAutoDiff (T *CI, TargetLibraryInfo &TLI, bool PostOpt,
104
+ bool HandleAutoDiff (CallInst *CI, TargetLibraryInfo &TLI, bool PostOpt,
106
105
bool fwdMode) {
107
106
108
107
Value *fn = CI->getArgOperand (0 );
@@ -575,9 +574,65 @@ class Enzyme : public ModulePass {
575
574
576
575
bool Changed = false ;
577
576
577
+ for (BasicBlock &BB : F)
578
+ if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator ())) {
579
+
580
+ Function *Fn = II->getCalledFunction ();
581
+
582
+ #if LLVM_VERSION_MAJOR >= 11
583
+ if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand ()))
584
+ #else
585
+ if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledValue ()))
586
+ #endif
587
+ {
588
+ if (castinst->isCast ())
589
+ if (auto fn = dyn_cast<Function>(castinst->getOperand (0 )))
590
+ Fn = fn;
591
+ }
592
+ if (!Fn)
593
+ continue ;
594
+
595
+ if (!(Fn->getName () == " __enzyme_float" ||
596
+ Fn->getName () == " __enzyme_double" ||
597
+ Fn->getName () == " __enzyme_integer" ||
598
+ Fn->getName () == " __enzyme_pointer" ||
599
+ Fn->getName ().contains (" __enzyme_call_inactive" ) ||
600
+ Fn->getName ().contains (" __enzyme_autodiff" ) ||
601
+ Fn->getName ().contains (" __enzyme_fwddiff" )))
602
+ continue ;
603
+
604
+ SmallVector<Value *, 16 > CallArgs (II->arg_begin (), II->arg_end ());
605
+ SmallVector<OperandBundleDef, 1 > OpBundles;
606
+ II->getOperandBundlesAsDefs (OpBundles);
607
+ // Insert a normal call instruction...
608
+ #if LLVM_VERSION_MAJOR >= 8
609
+ CallInst *NewCall =
610
+ CallInst::Create (II->getFunctionType (), II->getCalledOperand (),
611
+ CallArgs, OpBundles, " " , II);
612
+ #else
613
+ CallInst *NewCall =
614
+ CallInst::Create (II->getFunctionType (), II->getCalledValue (),
615
+ CallArgs, OpBundles, " " , II);
616
+ #endif
617
+ NewCall->takeName (II);
618
+ NewCall->setCallingConv (II->getCallingConv ());
619
+ NewCall->setAttributes (II->getAttributes ());
620
+ NewCall->setDebugLoc (II->getDebugLoc ());
621
+ II->replaceAllUsesWith (NewCall);
622
+
623
+ // Insert an unconditional branch to the normal destination.
624
+ BranchInst::Create (II->getNormalDest (), II);
625
+
626
+ // Remove any PHI node entries from the exception destination.
627
+ II->getUnwindDest ()->removePredecessor (&BB);
628
+
629
+ // Remove the invoke instruction now.
630
+ BB.getInstList ().erase (II);
631
+ Changed = true ;
632
+ }
633
+
578
634
std::set<CallInst *> toLowerAuto;
579
635
std::set<CallInst *> toLowerFwd;
580
- std::set<InvokeInst *> toLowerI;
581
636
std::set<CallInst *> InactiveCalls;
582
637
retry:;
583
638
for (BasicBlock &BB : F) {
@@ -752,15 +807,9 @@ class Enzyme : public ModulePass {
752
807
}
753
808
}
754
809
755
- bool autoDiff = Fn && (Fn->getName () == " __enzyme_autodiff" ||
756
- Fn->getName () == " enzyme_autodiff_" ||
757
- Fn->getName ().startswith (" __enzyme_autodiff" ) ||
758
- Fn->getName ().contains (" __enzyme_autodiff" ));
810
+ bool autoDiff = Fn && Fn->getName ().contains (" __enzyme_autodiff" );
759
811
760
- bool fwdDiff = Fn && (Fn->getName () == " __enzyme_fwddiff" ||
761
- Fn->getName () == " enzyme_fwddiff_" ||
762
- Fn->getName ().startswith (" __enzyme_fwddiff" ) ||
763
- Fn->getName ().contains (" __enzyme_fwddiff" ));
812
+ bool fwdDiff = Fn && Fn->getName ().contains (" __enzyme_fwddiff" );
764
813
765
814
if (autoDiff || fwdDiff) {
766
815
if (autoDiff) {
@@ -845,13 +894,6 @@ class Enzyme : public ModulePass {
845
894
break ;
846
895
}
847
896
848
- for (auto CI : toLowerI) {
849
- successful &= HandleAutoDiff (CI, TLI, PostOpt, /* fwdMode*/ false );
850
- Changed = true ;
851
- if (!successful)
852
- break ;
853
- }
854
-
855
897
if (Changed) {
856
898
// TODO consider enabling when attributor does not delete
857
899
// dead internal functions, which invalidates Enzyme's cache
0 commit comments