Skip to content

Commit ced939b

Browse files
authored
[AutoDiff] Fix subset parameters thunk partial_apply substitutions. (#27604)
Fix `partial_apply` substitution map for subset parameters linear map thunk. The correct substitution map is computed by `buildThunkType` in the helper `ADContext::getOrCreateSubsetParametersThunkForLinearMap` and is now returned by the helper. Resolves TF-886.
1 parent 431cc43 commit ced939b

File tree

2 files changed

+40
-22
lines changed

2 files changed

+40
-22
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,9 +1145,11 @@ class ADContext {
11451145
/// purposes.
11461146
void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);
11471147

1148-
/// Get or create a derivative function index subset thunk from
1149-
/// `actualIndices` to `desiredIndices` for the given derivative function
1150-
/// value and original function operand.
1148+
/// Get or create a derivative function parameter index subset thunk from
1149+
/// `actualIndices` to `desiredIndices` for the given associated function
1150+
/// value and original function operand. Returns a pair of the parameter
1151+
/// index subset thunk and its interface substitution map (used to partially
1152+
/// apply the thunk).
11511153
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
11521154
/// map returned by the derivative function.
11531155
std::pair<SILFunction *, SubstitutionMap>
@@ -1156,11 +1158,14 @@ class ADContext {
11561158
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
11571159
SILAutoDiffIndices actualIndices);
11581160

1159-
/// Get or create a derivative function index subset thunk from
1160-
/// `actualIndices` to `desiredIndices` for the given derivative function
1161-
/// value and original function operand.
1162-
SILFunction *getOrCreateSubsetParametersThunkForLinearMap(
1163-
SILFunction *derivativeFn, CanSILFunctionType linearMapType,
1161+
/// Get or create a derivative function parameter index subset thunk from
1162+
/// `actualIndices` to `desiredIndices` for the given associated function
1163+
/// value and original function operand. Returns a pair of the parameter
1164+
/// index subset thunk and its interface substitution map (used to partially
1165+
/// apply the thunk).
1166+
std::pair<SILFunction *, SubstitutionMap>
1167+
getOrCreateSubsetParametersThunkForLinearMap(
1168+
SILFunction *assocFn, CanSILFunctionType linearMapType,
11641169
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
11651170
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
11661171

@@ -8098,7 +8103,7 @@ class Differentiation : public SILModuleTransform {
80988103
};
80998104
} // end anonymous namespace
81008105

8101-
SILFunction *
8106+
std::pair<SILFunction *, SubstitutionMap>
81028107
ADContext::getOrCreateSubsetParametersThunkForLinearMap(
81038108
SILFunction *parentThunk, CanSILFunctionType linearMapType,
81048109
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
@@ -8107,8 +8112,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
81078112
<< "Getting a subset parameters thunk for " << linearMapType
81088113
<< " from " << actualIndices << " to " << desiredIndices << '\n');
81098114

8110-
SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap();
8111-
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment();
8115+
SubstitutionMap interfaceSubs;
8116+
GenericEnvironment *genericEnv = nullptr;
81128117
auto thunkType = buildThunkType(
81138118
parentThunk, linearMapType, targetType, genericEnv, interfaceSubs,
81148119
/*withoutActuallyEscaping*/ true,
@@ -8141,7 +8146,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
81418146
ProfileCounter(), IsThunk, IsNotDynamic);
81428147

81438148
if (!thunk->empty())
8144-
return thunk;
8149+
return {thunk, interfaceSubs};
81458150

81468151
thunk->setGenericEnvironment(genericEnv);
81478152
thunk->setOwnershipEliminated();
@@ -8289,7 +8294,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
82898294
for (auto *alloc : reversed(localAllocations))
82908295
builder.createDeallocStack(loc, alloc);
82918296
builder.createReturn(loc, ai);
8292-
return thunk;
8297+
return {thunk, interfaceSubs};
82938298
}
82948299

82958300
// If pullback thunk, return only the desired results and clean up the
@@ -8325,7 +8330,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
83258330
builder.createReturn(loc, result);
83268331

83278332
getGeneratedFunctions().push_back(thunk);
8328-
return thunk;
8333+
return {thunk, interfaceSubs};
83298334
}
83308335

83318336
std::pair<SILFunction *, SubstitutionMap>
@@ -8465,22 +8470,25 @@ ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction(
84658470
auto linearMapTargetType = targetType->getResults().back().getSILStorageType()
84668471
.castTo<SILFunctionType>();
84678472

8468-
auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap(
8469-
thunk, linearMapType, linearMapTargetType, kind,
8470-
desiredIndices, actualIndices);
8473+
SILFunction *linearMapThunk;
8474+
SubstitutionMap linearMapSubs;
8475+
std::tie(linearMapThunk, linearMapSubs) =
8476+
getOrCreateSubsetParametersThunkForLinearMap(
8477+
thunk, linearMapType, linearMapTargetType, kind,
8478+
desiredIndices, actualIndices);
84718479

8472-
auto *innerThunkFRI = builder.createFunctionRef(loc, innerThunk);
8473-
auto *newDerivative = builder.createPartialApply(
8474-
loc, innerThunkFRI, thunk->getForwardingSubstitutionMap(), {linearMap},
8480+
auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
8481+
auto *thunkedLinearMap = builder.createPartialApply(
8482+
loc, linearMapThunkFRI, linearMapSubs, {linearMap},
84758483
ParameterConvention::Direct_Guaranteed);
84768484

84778485
assert(origFnType->getResults().size() == 1);
84788486
if (origFnType->getResults().front().isFormalDirect()) {
84798487
auto result = joinElements(
8480-
{originalDirectResult, newDerivative}, builder, loc);
8488+
{originalDirectResult, thunkedLinearMap}, builder, loc);
84818489
builder.createReturn(loc, result);
84828490
} else {
8483-
builder.createReturn(loc, newDerivative);
8491+
builder.createReturn(loc, thunkedLinearMap);
84848492
}
84858493

84868494
getGeneratedFunctions().push_back(thunk);

test/AutoDiff/generics.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ extension TF_817 {
293293
}
294294
}
295295

296+
// TF-886: Test `partial_apply` of linear map subset parameters thunk.
297+
@differentiable
298+
func TF_886_foo<T, U: Differentiable>(_: Float, _: T, _: U) -> Float {
299+
return 0
300+
}
301+
@differentiable
302+
func TF_886_bar<T>(x: Float, y: T) -> Float {
303+
return TF_886_foo(x, y, 0)
304+
}
305+
296306
// Test layout requirements.
297307

298308
// The layout requirement is "contextual": the requirement is not on `T`, the

0 commit comments

Comments
 (0)