@@ -585,8 +585,9 @@ class AdjointGenerator
585
585
// the instruction if the value is a potential pointer. This may not be
586
586
// caught by type analysis is the result does not have a known type.
587
587
if (!gutils->isConstantInstruction (&I)) {
588
- Type *isfloat =
589
- type->isFPOrFPVectorTy () ? type->getScalarType () : nullptr ;
588
+ Type *isfloat = I.getType ()->isFPOrFPVectorTy ()
589
+ ? I.getType ()->getScalarType ()
590
+ : nullptr ;
590
591
if (!isfloat && type->isIntOrIntVectorTy ()) {
591
592
auto LoadSize = DL.getTypeSizeInBits (type) / 8 ;
592
593
ConcreteType vd = BaseType::Unknown;
@@ -610,40 +611,48 @@ class AdjointGenerator
610
611
getForwardBuilder (Builder2);
611
612
612
613
if (!gutils->isConstantValue (&I)) {
614
+ Value *ip = gutils->invertPointerM (I.getOperand (0 ), Builder2);
615
+
613
616
Value *diff;
614
617
if (!mask) {
618
+
619
+ auto rule = [&](Value *ip) {
615
620
#if LLVM_VERSION_MAJOR > 7
616
- auto LI = Builder2.CreateLoad (
617
- cast<PointerType>(I.getOperand (0 )->getType ())
618
- ->getElementType (),
619
- gutils->invertPointerM (I.getOperand (0 ), Builder2));
621
+ auto LI = Builder2.CreateLoad (I.getType (), ip);
620
622
#else
621
- auto LI = Builder2.CreateLoad (
622
- gutils->invertPointerM (I.getOperand (0 ), Builder2));
623
+ auto LI = Builder2.CreateLoad (ip);
623
624
#endif
624
- if (alignment)
625
+ if (alignment)
625
626
#if LLVM_VERSION_MAJOR >= 10
626
- LI->setAlignment (*alignment);
627
+ LI->setAlignment (*alignment);
627
628
#else
628
- LI->setAlignment (alignment);
629
+ LI->setAlignment (alignment);
629
630
#endif
630
- diff = LI;
631
+ return LI;
632
+ };
633
+
634
+ diff = applyChainRule (I.getType (), Builder2, rule, ip);
635
+
631
636
} else {
632
- Type *tys[] = {I.getType (), I.getOperand (0 )->getType ()};
633
- auto F = Intrinsic::getDeclaration (gutils->oldFunc ->getParent (),
634
- Intrinsic::masked_load, tys);
637
+ auto mi = diffe (orig_maskInit, Builder2);
638
+
639
+ auto rule = [&](Value *ip, Value *mi) {
640
+ Type *tys[] = {I.getType (), I.getOperand (0 )->getType ()};
641
+ auto F = Intrinsic::getDeclaration (gutils->oldFunc ->getParent (),
642
+ Intrinsic::masked_load, tys);
635
643
#if LLVM_VERSION_MAJOR >= 10
636
- Value *alignv =
637
- ConstantInt::get (Type::getInt32Ty (mask->getContext ()),
638
- alignment ? alignment->value () : 0 );
644
+ Value *alignv =
645
+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()),
646
+ alignment ? alignment->value () : 0 );
639
647
#else
640
- Value *alignv = ConstantInt::get (
641
- Type::getInt32Ty (mask->getContext ()), alignment);
648
+ Value *alignv = ConstantInt::get (
649
+ Type::getInt32Ty (mask->getContext ()), alignment);
642
650
#endif
643
- Value *args[] = {
644
- gutils->invertPointerM (I.getOperand (0 ), Builder2), alignv,
645
- mask, diffe (orig_maskInit, Builder2)};
646
- diff = Builder2.CreateCall (F, args);
651
+ Value *args[] = {ip, alignv, mask, mi};
652
+ return Builder2.CreateCall (F, args);
653
+ };
654
+
655
+ diff = applyChainRule (I.getType (), Builder2, rule, ip, mi);
647
656
}
648
657
setDiffe (&I, diff, Builder2);
649
658
}
@@ -869,10 +878,13 @@ class AdjointGenerator
869
878
IRBuilder<> Builder2 (&I);
870
879
getForwardBuilder (Builder2);
871
880
872
- Value *diff = constantval ? Constant::getNullValue (valType)
881
+ Type *diffeTy = gutils->getShadowType (valType);
882
+
883
+ Value *diff = constantval ? Constant::getNullValue (diffeTy)
873
884
: diffe (orig_val, Builder2);
874
885
gutils->setPtrDiffe (orig_ptr, diff, Builder2, align, isVolatile,
875
886
ordering, syncScope, mask);
887
+
876
888
break ;
877
889
}
878
890
}
@@ -889,6 +901,14 @@ class AdjointGenerator
889
901
890
902
if (constantval) {
891
903
valueop = val;
904
+ if (gutils->getWidth () > 1 ) {
905
+ Value *array =
906
+ UndefValue::get (gutils->getShadowType (val->getType ()));
907
+ for (unsigned i = 0 ; i < gutils->getWidth (); ++i) {
908
+ array = storeBuilder.CreateInsertValue (array, val, {i});
909
+ }
910
+ valueop = array;
911
+ }
892
912
} else {
893
913
valueop = gutils->invertPointerM (orig_val, storeBuilder);
894
914
}
@@ -915,37 +935,38 @@ class AdjointGenerator
915
935
}
916
936
case DerivativeMode::ForwardModeSplit:
917
937
case DerivativeMode::ForwardMode: {
918
- break ;
919
- }
920
- }
938
+ BasicBlock *oBB = phi.getParent ();
939
+ BasicBlock *nBB = gutils->getNewFromOriginal (oBB);
921
940
922
- BasicBlock *oBB = phi. getParent ( );
923
- BasicBlock *nBB = gutils-> getNewFromOriginal (oBB );
941
+ IRBuilder<> diffeBuilder (nBB-> getFirstNonPHI () );
942
+ diffeBuilder. setFastMathFlags ( getFast () );
924
943
925
- IRBuilder<> diffeBuilder (nBB-> getFirstNonPHI () );
926
- diffeBuilder. setFastMathFlags ( getFast () );
944
+ IRBuilder<> phiBuilder (&phi );
945
+ getForwardBuilder (phiBuilder );
927
946
928
- IRBuilder<> phiBuilder (&phi);
929
- getForwardBuilder (phiBuilder);
947
+ Type *diffeType = gutils->getShadowType (phi.getType ());
930
948
931
- auto newPhi = phiBuilder.CreatePHI (phi. getType () , 1 , phi.getName () + " '" );
932
- for (unsigned int i = 0 ; i < phi.getNumIncomingValues (); ++i) {
933
- auto val = phi.getIncomingValue (i);
934
- auto block = phi.getIncomingBlock (i);
949
+ auto newPhi = phiBuilder.CreatePHI (diffeType , 1 , phi.getName () + " '" );
950
+ for (unsigned int i = 0 ; i < phi.getNumIncomingValues (); ++i) {
951
+ auto val = phi.getIncomingValue (i);
952
+ auto block = phi.getIncomingBlock (i);
935
953
936
- auto newBlock = gutils->getNewFromOriginal (block);
937
- IRBuilder<> pBuilder (newBlock->getTerminator ());
938
- pBuilder.setFastMathFlags (getFast ());
954
+ auto newBlock = gutils->getNewFromOriginal (block);
955
+ IRBuilder<> pBuilder (newBlock->getTerminator ());
956
+ pBuilder.setFastMathFlags (getFast ());
939
957
940
- if (gutils->isConstantValue (val)) {
941
- newPhi->addIncoming (Constant::getNullValue (val->getType ()), newBlock);
942
- } else {
943
- auto diff = diffe (val, pBuilder);
944
- newPhi->addIncoming (diff, newBlock);
958
+ if (gutils->isConstantValue (val)) {
959
+ newPhi->addIncoming (Constant::getNullValue (diffeType), newBlock);
960
+ } else {
961
+ auto diff = diffe (val, pBuilder);
962
+ newPhi->addIncoming (diff, newBlock);
963
+ }
945
964
}
946
- }
947
965
948
- setDiffe (&phi, newPhi, diffeBuilder);
966
+ setDiffe (&phi, newPhi, diffeBuilder);
967
+ return ;
968
+ }
969
+ }
949
970
}
950
971
951
972
void visitCastInst (llvm::CastInst &I) {
@@ -2589,11 +2610,26 @@ class AdjointGenerator
2589
2610
dsrc = Builder2.CreateIntToPtr (dsrc,
2590
2611
Type::getInt8PtrTy (dsrc->getContext ()));
2591
2612
2592
- auto call =
2593
- Builder2.CreateMemCpy (ddst, dstAlign, dsrc, srcAlign, new_size);
2594
- call->setAttributes (MTI.getAttributes ());
2595
- call->setTailCallKind (MTI.getTailCallKind ());
2613
+ auto rule = [&](Value *ddst, Value *dsrc) {
2614
+ CallInst *call;
2615
+ if (ID == Intrinsic::memmove) {
2616
+ call =
2617
+ Builder2.CreateMemMove (ddst, dstAlign, dsrc, srcAlign, new_size);
2618
+ } else {
2619
+ call =
2620
+ Builder2.CreateMemCpy (ddst, dstAlign, dsrc, srcAlign, new_size);
2621
+ }
2622
+ call->setAttributes (MTI.getAttributes ());
2623
+ call->setMetadata (LLVMContext::MD_tbaa,
2624
+ MTI.getMetadata (LLVMContext::MD_tbaa));
2625
+ call->setMetadata (LLVMContext::MD_tbaa_struct,
2626
+ MTI.getMetadata (LLVMContext::MD_tbaa_struct));
2627
+ call->setMetadata (LLVMContext::MD_invariant_group,
2628
+ MTI.getMetadata (LLVMContext::MD_invariant_group));
2629
+ call->setTailCallKind (MTI.getTailCallKind ());
2630
+ };
2596
2631
2632
+ applyChainRule (Builder2, rule, ddst, dsrc);
2597
2633
return ;
2598
2634
}
2599
2635
0 commit comments