@@ -778,6 +778,56 @@ class AdjointGenerator
778
778
}
779
779
780
780
void visitAtomicRMWInst (llvm::AtomicRMWInst &I) {
781
+ if (Mode == DerivativeMode::ForwardMode) {
782
+ IRBuilder<> BuilderZ (&I);
783
+ getForwardBuilder (BuilderZ);
784
+ switch (I.getOperation ()) {
785
+ case AtomicRMWInst::FAdd:
786
+ case AtomicRMWInst::FSub: {
787
+ auto rule = [&](Value *ptr, Value *dif) -> Value * {
788
+ if (!gutils->isConstantInstruction (&I)) {
789
+ assert (ptr);
790
+ AtomicRMWInst *rmw = nullptr ;
791
+ #if LLVM_VERSION_MAJOR >= 13
792
+ rmw = BuilderZ.CreateAtomicRMW (I.getOperation (), ptr, dif,
793
+ I.getAlign (), I.getOrdering (),
794
+ I.getSyncScopeID ());
795
+ #elif LLVM_VERSION_MAJOR >= 11
796
+ rmw = BuilderZ.CreateAtomicRMW (I.getOperation (), ptr, dif,
797
+ I.getOrdering (), I.getSyncScopeID ());
798
+ rmw->setAlignment (I.getAlign ());
799
+ #else
800
+ rmw = BuilderZ.CreateAtomicRMW (
801
+ I.getOperation (), ptr, dif, I.getOrdering (),
802
+ I.getSyncScopeID ());
803
+ #endif
804
+ rmw->setVolatile (I.isVolatile ());
805
+ if (gutils->isConstantValue (&I))
806
+ return Constant::getNullValue (dif->getType ());
807
+ else
808
+ return rmw;
809
+ } else {
810
+ assert (gutils->isConstantValue (&I));
811
+ return Constant::getNullValue (dif->getType ());
812
+ }
813
+ };
814
+
815
+ Value *diff = applyChainRule (
816
+ I.getType (), BuilderZ, rule,
817
+ gutils->isConstantValue (I.getPointerOperand ())
818
+ ? nullptr
819
+ : gutils->invertPointerM (I.getPointerOperand (), BuilderZ),
820
+ gutils->isConstantValue (I.getValOperand ())
821
+ ? Constant::getNullValue (I.getType ())
822
+ : gutils->invertPointerM (I.getValOperand (), BuilderZ));
823
+ if (!gutils->isConstantValue (&I))
824
+ setDiffe (&I, diff, BuilderZ);
825
+ return ;
826
+ }
827
+ default :
828
+ break ;
829
+ }
830
+ }
781
831
if (!gutils->isConstantInstruction (&I) || !gutils->isConstantValue (&I)) {
782
832
TR.dump ();
783
833
llvm::errs () << " oldFunc: " << *gutils->newFunc << " \n " ;
@@ -11083,7 +11133,8 @@ class AdjointGenerator
11083
11133
auto rule = [&args](Value *tofree) { args.push_back (tofree); };
11084
11134
applyChainRule (Builder2, rule, tofree);
11085
11135
11086
- Builder2.CreateCall (free->getFunctionType (), free, args);
11136
+ auto frees = Builder2.CreateCall (free->getFunctionType (), free, args);
11137
+ frees->setDebugLoc (gutils->getNewFromOriginal (orig->getDebugLoc ()));
11087
11138
11088
11139
return ;
11089
11140
}
0 commit comments