Skip to content

Commit cba5a2b

Browse files
committed
[AutoDiff] Make VJP applications use the correct substitution map.
If a custom `@differentiable` attribute defines a VJP and where clause requirements, VJP applications should use a substitution map involving those requirements. Note: more related cases need to be handled, such as `@differentiable` attributes with where clause requirements but no VJP. These cases will be handled later.
1 parent 99912b3 commit cba5a2b

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,9 +2441,16 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
24412441
newArgs.push_back(getOpValue(origArg));
24422442
assert(newArgs.size() == numVJPParams);
24432443
// Apply the VJP.
2444-
auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp,
2445-
ai->getSubstitutionMap(), newArgs,
2446-
ai->isNonThrowing());
2444+
auto substMap = ai->getSubstitutionMap();
2445+
if (auto vjpGenSig = vjpFnTy->getGenericSignature()) {
2446+
auto vjpSubstMap =
2447+
vjpGenSig->createGenericEnvironment()->getForwardingSubstitutionMap();
2448+
substMap = vjpSubstMap.subst(
2449+
[&](SubstitutableType *ty) { return Type(ty).subst(substMap); },
2450+
LookUpConformanceInModule(context.getModule().getSwiftModule()));
2451+
}
2452+
auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp, substMap,
2453+
newArgs, ai->isNonThrowing());
24472454
LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);
24482455

24492456
// Get the VJP results (original results and pullback).

0 commit comments

Comments
 (0)