Skip to content

Commit 1c7b1aa

Browse files
committed
[AutoDiff] Fix partial_apply substitution map for subset parameters thunk.
Fix `partial_apply` substitution map for subset parameters linear map thunk. The correct substitution map is computed by `buildThunkType` in the helper `ADContext::getOrCreateSubsetParametersThunkForLinearMap` and is now returned by the helper. Resolves TF-886.
1 parent fb6045c commit 1c7b1aa

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,8 @@ class ADContext {
11601160
/// Get or create an associated function index subset thunk from
11611161
/// `actualIndices` to `desiredIndices` for the given associated function
11621162
/// value and original function operand.
1163-
SILFunction *getOrCreateSubsetParametersThunkForLinearMap(
1163+
std::pair<SILFunction *, SubstitutionMap>
1164+
getOrCreateSubsetParametersThunkForLinearMap(
11641165
SILFunction *assocFn, CanSILFunctionType linearMapType,
11651166
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
11661167
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
@@ -8105,7 +8106,7 @@ class Differentiation : public SILModuleTransform {
81058106
};
81068107
} // end anonymous namespace
81078108

8108-
SILFunction *
8109+
std::pair<SILFunction *, SubstitutionMap>
81098110
ADContext::getOrCreateSubsetParametersThunkForLinearMap(
81108111
SILFunction *parentThunk, CanSILFunctionType linearMapType,
81118112
CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind,
@@ -8114,8 +8115,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
81148115
<< "Getting a subset parameters thunk for " << linearMapType
81158116
<< " from " << actualIndices << " to " << desiredIndices << '\n');
81168117

8117-
SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap();
8118-
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment();
8118+
SubstitutionMap interfaceSubs;
8119+
GenericEnvironment *genericEnv = nullptr;
81198120
auto thunkType = buildThunkType(
81208121
parentThunk, linearMapType, targetType, genericEnv, interfaceSubs,
81218122
/*withoutActuallyEscaping*/ true,
@@ -8148,7 +8149,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
81488149
ProfileCounter(), IsThunk, IsNotDynamic);
81498150

81508151
if (!thunk->empty())
8151-
return thunk;
8152+
return {thunk, interfaceSubs};
81528153

81538154
thunk->setGenericEnvironment(genericEnv);
81548155
thunk->setOwnershipEliminated();
@@ -8296,7 +8297,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
82968297
for (auto *alloc : reversed(localAllocations))
82978298
builder.createDeallocStack(loc, alloc);
82988299
builder.createReturn(loc, ai);
8299-
return thunk;
8300+
return {thunk, interfaceSubs};
83008301
}
83018302

83028303
// If pullback thunk, return only the desired results and clean up the
@@ -8332,7 +8333,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
83328333
builder.createReturn(loc, result);
83338334

83348335
getGeneratedFunctions().push_back(thunk);
8335-
return thunk;
8336+
return {thunk, interfaceSubs};
83368337
}
83378338

83388339
std::pair<SILFunction *, SubstitutionMap>
@@ -8472,22 +8473,25 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction(
84728473
auto linearMapTargetType = targetType->getResults().back().getSILStorageType()
84738474
.castTo<SILFunctionType>();
84748475

8475-
auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap(
8476-
thunk, linearMapType, linearMapTargetType, kind,
8477-
desiredIndices, actualIndices);
8476+
SILFunction *linearMapThunk;
8477+
SubstitutionMap linearMapSubs;
8478+
std::tie(linearMapThunk, linearMapSubs) =
8479+
getOrCreateSubsetParametersThunkForLinearMap(
8480+
thunk, linearMapType, linearMapTargetType, kind,
8481+
desiredIndices, actualIndices);
84788482

8479-
auto *innerThunkFRI = builder.createFunctionRef(loc, innerThunk);
8480-
auto *newDerivative = builder.createPartialApply(
8481-
loc, innerThunkFRI, thunk->getForwardingSubstitutionMap(), {linearMap},
8483+
auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
8484+
auto *thunkedLinearMap = builder.createPartialApply(
8485+
loc, linearMapThunkFRI, linearMapSubs, {linearMap},
84828486
ParameterConvention::Direct_Guaranteed);
84838487

84848488
assert(origFnType->getResults().size() == 1);
84858489
if (origFnType->getResults().front().isFormalDirect()) {
84868490
auto result = joinElements(
8487-
{originalDirectResult, newDerivative}, builder, loc);
8491+
{originalDirectResult, thunkedLinearMap}, builder, loc);
84888492
builder.createReturn(loc, result);
84898493
} else {
8490-
builder.createReturn(loc, newDerivative);
8494+
builder.createReturn(loc, thunkedLinearMap);
84918495
}
84928496

84938497
getGeneratedFunctions().push_back(thunk);

test/AutoDiff/generics.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ extension TF_817 {
293293
}
294294
}
295295

296+
// TF-886: Test `partial_apply` of linear map subset parameters thunk.
297+
@differentiable
298+
func TF_886_foo<T, U: Differentiable>(_: Float, _: T, _: U) -> Float {
299+
return 0
300+
}
301+
@differentiable
302+
func TF_886_bar<T>(x: Float, y: T) -> Float {
303+
return TF_886_foo(x, y, 0)
304+
}
305+
296306
// Test layout requirements.
297307

298308
// The layout requirement is "contextual": the requirement is not on `T`, the

0 commit comments

Comments
 (0)