@@ -1160,7 +1160,8 @@ class ADContext {
1160
1160
// / Get or create an associated function index subset thunk from
1161
1161
// / `actualIndices` to `desiredIndices` for the given associated function
1162
1162
// / value and original function operand.
1163
- SILFunction *getOrCreateSubsetParametersThunkForLinearMap (
1163
+ std::pair<SILFunction *, SubstitutionMap>
1164
+ getOrCreateSubsetParametersThunkForLinearMap (
1164
1165
SILFunction *assocFn, CanSILFunctionType linearMapType,
1165
1166
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
1166
1167
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
@@ -8105,7 +8106,7 @@ class Differentiation : public SILModuleTransform {
8105
8106
};
8106
8107
} // end anonymous namespace
8107
8108
8108
- SILFunction *
8109
+ std::pair< SILFunction *, SubstitutionMap>
8109
8110
ADContext::getOrCreateSubsetParametersThunkForLinearMap (
8110
8111
SILFunction *parentThunk, CanSILFunctionType linearMapType,
8111
8112
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
@@ -8114,8 +8115,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8114
8115
<< " Getting a subset parameters thunk for " << linearMapType
8115
8116
<< " from " << actualIndices << " to " << desiredIndices << ' \n ' );
8116
8117
8117
- SubstitutionMap interfaceSubs = parentThunk-> getForwardingSubstitutionMap () ;
8118
- GenericEnvironment *genericEnv = parentThunk-> getGenericEnvironment () ;
8118
+ SubstitutionMap interfaceSubs;
8119
+ GenericEnvironment *genericEnv = nullptr ;
8119
8120
auto thunkType = buildThunkType (
8120
8121
parentThunk, linearMapType, targetType, genericEnv, interfaceSubs,
8121
8122
/* withoutActuallyEscaping*/ true ,
@@ -8148,7 +8149,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8148
8149
ProfileCounter (), IsThunk, IsNotDynamic);
8149
8150
8150
8151
if (!thunk->empty ())
8151
- return thunk;
8152
+ return { thunk, interfaceSubs} ;
8152
8153
8153
8154
thunk->setGenericEnvironment (genericEnv);
8154
8155
thunk->setOwnershipEliminated ();
@@ -8296,7 +8297,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8296
8297
for (auto *alloc : reversed (localAllocations))
8297
8298
builder.createDeallocStack (loc, alloc);
8298
8299
builder.createReturn (loc, ai);
8299
- return thunk;
8300
+ return { thunk, interfaceSubs} ;
8300
8301
}
8301
8302
8302
8303
// If pullback thunk, return only the desired results and clean up the
@@ -8332,7 +8333,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8332
8333
builder.createReturn (loc, result);
8333
8334
8334
8335
getGeneratedFunctions ().push_back (thunk);
8335
- return thunk;
8336
+ return { thunk, interfaceSubs} ;
8336
8337
}
8337
8338
8338
8339
std::pair<SILFunction *, SubstitutionMap>
@@ -8472,22 +8473,25 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction(
8472
8473
auto linearMapTargetType = targetType->getResults ().back ().getSILStorageType ()
8473
8474
.castTo <SILFunctionType>();
8474
8475
8475
- auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap (
8476
- thunk, linearMapType, linearMapTargetType, kind,
8477
- desiredIndices, actualIndices);
8476
+ SILFunction *linearMapThunk;
8477
+ SubstitutionMap linearMapSubs;
8478
+ std::tie (linearMapThunk, linearMapSubs) =
8479
+ getOrCreateSubsetParametersThunkForLinearMap (
8480
+ thunk, linearMapType, linearMapTargetType, kind,
8481
+ desiredIndices, actualIndices);
8478
8482
8479
- auto *innerThunkFRI = builder.createFunctionRef (loc, innerThunk );
8480
- auto *newDerivative = builder.createPartialApply (
8481
- loc, innerThunkFRI, thunk-> getForwardingSubstitutionMap () , {linearMap},
8483
+ auto *linearMapThunkFRI = builder.createFunctionRef (loc, linearMapThunk );
8484
+ auto *thunkedLinearMap = builder.createPartialApply (
8485
+ loc, linearMapThunkFRI, linearMapSubs , {linearMap},
8482
8486
ParameterConvention::Direct_Guaranteed);
8483
8487
8484
8488
assert (origFnType->getResults ().size () == 1 );
8485
8489
if (origFnType->getResults ().front ().isFormalDirect ()) {
8486
8490
auto result = joinElements (
8487
- {originalDirectResult, newDerivative }, builder, loc);
8491
+ {originalDirectResult, thunkedLinearMap }, builder, loc);
8488
8492
builder.createReturn (loc, result);
8489
8493
} else {
8490
- builder.createReturn (loc, newDerivative );
8494
+ builder.createReturn (loc, thunkedLinearMap );
8491
8495
}
8492
8496
8493
8497
getGeneratedFunctions ().push_back (thunk);
0 commit comments