@@ -1145,9 +1145,11 @@ class ADContext {
1145
1145
// / purposes.
1146
1146
void foldDifferentiableFunctionExtraction (DifferentiableFunctionInst *source);
1147
1147
1148
- // / Get or create a derivative function index subset thunk from
1149
- // / `actualIndices` to `desiredIndices` for the given derivative function
1150
- // / value and original function operand.
1148
+ // / Get or create a derivative function parameter index subset thunk from
1149
+ // / `actualIndices` to `desiredIndices` for the given associated function
1150
+ // / value and original function operand. Returns a pair of the parameter
1151
+ // / index subset thunk and its interface substitution map (used to partially
1152
+ // / apply the thunk).
1151
1153
// / Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
1152
1154
// / map returned by the derivative function.
1153
1155
std::pair<SILFunction *, SubstitutionMap>
@@ -1156,11 +1158,14 @@ class ADContext {
1156
1158
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
1157
1159
SILAutoDiffIndices actualIndices);
1158
1160
1159
- // / Get or create a derivative function index subset thunk from
1160
- // / `actualIndices` to `desiredIndices` for the given derivative function
1161
- // / value and original function operand.
1162
- SILFunction *getOrCreateSubsetParametersThunkForLinearMap (
1163
- SILFunction *derivativeFn, CanSILFunctionType linearMapType,
1161
+ // / Get or create a derivative function parameter index subset thunk from
1162
+ // / `actualIndices` to `desiredIndices` for the given associated function
1163
+ // / value and original function operand. Returns a pair of the parameter
1164
+ // / index subset thunk and its interface substitution map (used to partially
1165
+ // / apply the thunk).
1166
+ std::pair<SILFunction *, SubstitutionMap>
1167
+ getOrCreateSubsetParametersThunkForLinearMap (
1168
+ SILFunction *assocFn, CanSILFunctionType linearMapType,
1164
1169
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
1165
1170
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
1166
1171
@@ -8098,7 +8103,7 @@ class Differentiation : public SILModuleTransform {
8098
8103
};
8099
8104
} // end anonymous namespace
8100
8105
8101
- SILFunction *
8106
+ std::pair< SILFunction *, SubstitutionMap>
8102
8107
ADContext::getOrCreateSubsetParametersThunkForLinearMap (
8103
8108
SILFunction *parentThunk, CanSILFunctionType linearMapType,
8104
8109
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
@@ -8107,8 +8112,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8107
8112
<< " Getting a subset parameters thunk for " << linearMapType
8108
8113
<< " from " << actualIndices << " to " << desiredIndices << ' \n ' );
8109
8114
8110
- SubstitutionMap interfaceSubs = parentThunk-> getForwardingSubstitutionMap () ;
8111
- GenericEnvironment *genericEnv = parentThunk-> getGenericEnvironment () ;
8115
+ SubstitutionMap interfaceSubs;
8116
+ GenericEnvironment *genericEnv = nullptr ;
8112
8117
auto thunkType = buildThunkType (
8113
8118
parentThunk, linearMapType, targetType, genericEnv, interfaceSubs,
8114
8119
/* withoutActuallyEscaping*/ true ,
@@ -8141,7 +8146,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8141
8146
ProfileCounter (), IsThunk, IsNotDynamic);
8142
8147
8143
8148
if (!thunk->empty ())
8144
- return thunk;
8149
+ return { thunk, interfaceSubs} ;
8145
8150
8146
8151
thunk->setGenericEnvironment (genericEnv);
8147
8152
thunk->setOwnershipEliminated ();
@@ -8289,7 +8294,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8289
8294
for (auto *alloc : reversed (localAllocations))
8290
8295
builder.createDeallocStack (loc, alloc);
8291
8296
builder.createReturn (loc, ai);
8292
- return thunk;
8297
+ return { thunk, interfaceSubs} ;
8293
8298
}
8294
8299
8295
8300
// If pullback thunk, return only the desired results and clean up the
@@ -8325,7 +8330,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
8325
8330
builder.createReturn (loc, result);
8326
8331
8327
8332
getGeneratedFunctions ().push_back (thunk);
8328
- return thunk;
8333
+ return { thunk, interfaceSubs} ;
8329
8334
}
8330
8335
8331
8336
std::pair<SILFunction *, SubstitutionMap>
@@ -8465,22 +8470,25 @@ ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction(
8465
8470
auto linearMapTargetType = targetType->getResults ().back ().getSILStorageType ()
8466
8471
.castTo <SILFunctionType>();
8467
8472
8468
- auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap (
8469
- thunk, linearMapType, linearMapTargetType, kind,
8470
- desiredIndices, actualIndices);
8473
+ SILFunction *linearMapThunk;
8474
+ SubstitutionMap linearMapSubs;
8475
+ std::tie (linearMapThunk, linearMapSubs) =
8476
+ getOrCreateSubsetParametersThunkForLinearMap (
8477
+ thunk, linearMapType, linearMapTargetType, kind,
8478
+ desiredIndices, actualIndices);
8471
8479
8472
- auto *innerThunkFRI = builder.createFunctionRef (loc, innerThunk );
8473
- auto *newDerivative = builder.createPartialApply (
8474
- loc, innerThunkFRI, thunk-> getForwardingSubstitutionMap () , {linearMap},
8480
+ auto *linearMapThunkFRI = builder.createFunctionRef (loc, linearMapThunk );
8481
+ auto *thunkedLinearMap = builder.createPartialApply (
8482
+ loc, linearMapThunkFRI, linearMapSubs , {linearMap},
8475
8483
ParameterConvention::Direct_Guaranteed);
8476
8484
8477
8485
assert (origFnType->getResults ().size () == 1 );
8478
8486
if (origFnType->getResults ().front ().isFormalDirect ()) {
8479
8487
auto result = joinElements (
8480
- {originalDirectResult, newDerivative }, builder, loc);
8488
+ {originalDirectResult, thunkedLinearMap }, builder, loc);
8481
8489
builder.createReturn (loc, result);
8482
8490
} else {
8483
- builder.createReturn (loc, newDerivative );
8491
+ builder.createReturn (loc, thunkedLinearMap );
8484
8492
}
8485
8493
8486
8494
getGeneratedFunctions ().push_back (thunk);
0 commit comments