Skip to content

Commit 75fee65

Browse files
committed
Fix internal byval usage
1 parent b690fac commit 75fee65

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4221,16 +4221,29 @@ class AdjointGenerator
42214221
std::vector<DIFFE_TYPE> argsInverted;
42224222
std::vector<Instruction *> postCreate;
42234223
std::vector<Instruction *> userReplace;
4224+
std::map<int, Type *> preByVal;
4225+
std::map<int, Type *> gradByVal;
42244226

42254227
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
42264228

42274229
auto argi = gutils->getNewFromOriginal(orig->getArgOperand(i));
42284230

4231+
#if LLVM_VERSION_MAJOR >= 9
4232+
if (orig->isByValArgument(i)) {
4233+
preByVal[pre_args.size()] = orig->getParamByValType(i);
4234+
}
4235+
#endif
4236+
42294237
pre_args.push_back(argi);
42304238

42314239
if (Mode != DerivativeMode::ReverseModePrimal) {
42324240
IRBuilder<> Builder2(call.getParent());
42334241
getReverseBuilder(Builder2);
4242+
#if LLVM_VERSION_MAJOR >= 9
4243+
if (orig->isByValArgument(i)) {
4244+
gradByVal[args.size()] = orig->getParamByValType(i);
4245+
}
4246+
#endif
42344247
args.push_back(lookup(argi, Builder2));
42354248
}
42364249

@@ -4467,6 +4480,13 @@ class AdjointGenerator
44674480
augmentcall->setCallingConv(orig->getCallingConv());
44684481
augmentcall->setDebugLoc(
44694482
gutils->getNewFromOriginal(orig->getDebugLoc()));
4483+
#if LLVM_VERSION_MAJOR >= 9
4484+
for (auto pair : preByVal) {
4485+
augmentcall->addParamAttr(
4486+
pair.first, Attribute::getWithByValType(augmentcall->getContext(),
4487+
pair.second));
4488+
}
4489+
#endif
44704490

44714491
if (!augmentcall->getType()->isVoidTy())
44724492
augmentcall->setName(orig->getName() + "_augmented");
@@ -4774,6 +4794,12 @@ class AdjointGenerator
47744794
CallInst *diffes = Builder2.CreateCall(FT, newcalled, args);
47754795
diffes->setCallingConv(orig->getCallingConv());
47764796
diffes->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
4797+
#if LLVM_VERSION_MAJOR >= 9
4798+
for (auto pair : gradByVal) {
4799+
diffes->addParamAttr(pair.first, Attribute::getWithByValType(
4800+
diffes->getContext(), pair.second));
4801+
}
4802+
#endif
47774803

47784804
unsigned structidx = retUsed ? 1 : 0;
47794805
if (subdretptr)

0 commit comments

Comments
 (0)