@@ -4221,16 +4221,29 @@ class AdjointGenerator
4221
4221
std::vector<DIFFE_TYPE> argsInverted;
4222
4222
std::vector<Instruction *> postCreate;
4223
4223
std::vector<Instruction *> userReplace;
4224
+ std::map<int , Type *> preByVal;
4225
+ std::map<int , Type *> gradByVal;
4224
4226
4225
4227
for (unsigned i = 0 ; i < orig->getNumArgOperands (); ++i) {
4226
4228
4227
4229
auto argi = gutils->getNewFromOriginal (orig->getArgOperand (i));
4228
4230
4231
+ #if LLVM_VERSION_MAJOR >= 9
4232
+ if (orig->isByValArgument (i)) {
4233
+ preByVal[pre_args.size ()] = orig->getParamByValType (i);
4234
+ }
4235
+ #endif
4236
+
4229
4237
pre_args.push_back (argi);
4230
4238
4231
4239
if (Mode != DerivativeMode::ReverseModePrimal) {
4232
4240
IRBuilder<> Builder2 (call.getParent ());
4233
4241
getReverseBuilder (Builder2);
4242
+ #if LLVM_VERSION_MAJOR >= 9
4243
+ if (orig->isByValArgument (i)) {
4244
+ gradByVal[args.size ()] = orig->getParamByValType (i);
4245
+ }
4246
+ #endif
4234
4247
args.push_back (lookup (argi, Builder2));
4235
4248
}
4236
4249
@@ -4467,6 +4480,13 @@ class AdjointGenerator
4467
4480
augmentcall->setCallingConv (orig->getCallingConv ());
4468
4481
augmentcall->setDebugLoc (
4469
4482
gutils->getNewFromOriginal (orig->getDebugLoc ()));
4483
+ #if LLVM_VERSION_MAJOR >= 9
4484
+ for (auto pair : preByVal) {
4485
+ augmentcall->addParamAttr (
4486
+ pair.first , Attribute::getWithByValType (augmentcall->getContext (),
4487
+ pair.second ));
4488
+ }
4489
+ #endif
4470
4490
4471
4491
if (!augmentcall->getType ()->isVoidTy ())
4472
4492
augmentcall->setName (orig->getName () + " _augmented" );
@@ -4774,6 +4794,12 @@ class AdjointGenerator
4774
4794
CallInst *diffes = Builder2.CreateCall (FT, newcalled, args);
4775
4795
diffes->setCallingConv (orig->getCallingConv ());
4776
4796
diffes->setDebugLoc (gutils->getNewFromOriginal (orig->getDebugLoc ()));
4797
+ #if LLVM_VERSION_MAJOR >= 9
4798
+ for (auto pair : gradByVal) {
4799
+ diffes->addParamAttr (pair.first , Attribute::getWithByValType (
4800
+ diffes->getContext (), pair.second ));
4801
+ }
4802
+ #endif
4777
4803
4778
4804
unsigned structidx = retUsed ? 1 : 0 ;
4779
4805
if (subdretptr)
0 commit comments