Skip to content

Commit 2f89554

Browse files
authored
[AutoDiff] Fix SILGen JVP/VJP thunking bug. (#26448)
SILGen JVP/VJP thunks for methods now always perform self reordering. Add leak checking test. Previously, self reordering was not performed if actual JVP/VJP type matched expected JVP/VJP type, which is incorrect. Gardening included: - Change `SILFunction *SILGenFunction::getOrCreateAutoDiffLinearMapThunk` to `ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap`. The new helper is more ergonomic and directly returns thunked values. - Use ad-hoc `_vtable_entry_thunk` suffix for JVP/VJP vtable entry thunks. - Add todo comments for TF-685: principled AD thunk mangling. Resolves TF-698.
1 parent 499c875 commit 2f89554

File tree

10 files changed

+209
-137
lines changed

10 files changed

+209
-137
lines changed

lib/SIL/SILFunctionType.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
157157
CanGenericSignature assocFnGenSig,
158158
ArrayRef<SILParameterInfo> originalParameters,
159159
AutoDiffIndexSubset *parameterIndices, SILModule &module) {
160-
// If associated function has no
161160
if (!assocFnGenSig)
162161
return nullptr;
163162
auto &ctx = module.getASTContext();

lib/SILGen/SILGen.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -810,13 +810,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
810810
indices.parameters, indices.source, /*differentiationOrder*/ 1,
811811
AutoDiffAssociatedFunctionKind::VJP, M, lookUpConformance);
812812

813+
// Self reordering is necessary if wrt at least two parameters, including
814+
// self.
815+
auto shouldReorderSelf = [&]() {
816+
if (!F->hasSelfParam())
817+
return false;
818+
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
819+
if (!indices.isWrtParameter(selfParamIndex))
820+
return false;
821+
return indices.parameters->getNumIndices() > 1;
822+
};
823+
bool reorderSelf = shouldReorderSelf();
824+
813825
// Thunk JVP method, if it is defined.
814826
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
815827
SILFunction *jvpThunk;
816828
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
817-
if (jvpFn->getLoweredFunctionType() != expectedJVPType) {
829+
if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) {
818830
jvpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
819-
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP);
831+
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP,
832+
reorderSelf);
820833
} else {
821834
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
822835
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
@@ -831,9 +844,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
831844
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
832845
SILFunction *vjpThunk;
833846
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
834-
if (vjpFn->getLoweredFunctionType() != expectedVJPType) {
847+
if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) {
835848
vjpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
836-
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP);
849+
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP,
850+
reorderSelf);
837851
} else {
838852
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
839853
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,

lib/SILGen/SILGen.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
187187
/// - The last result in the returned pullback.
188188
SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk(
189189
SILFunction *original, SILAutoDiffIndices &indices,
190-
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind);
190+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
191+
bool reorderSelf);
191192

192193
/// Determine whether the given class has any instance variables that
193194
/// need to be destroyed.

lib/SILGen/SILGenFunction.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,14 +1780,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
17801780
// Differentiation thunks
17811781
//===--------------------------------------------------------------------===//
17821782

1783-
/// Get or create a thunk for reabstracting differentials/pullbacks returned
1784-
/// by user-defined JVP/VJP functions.
1783+
/// Get or create a thunk for reabstracting and self-reordering
1784+
/// differentials/pullbacks returned by user-defined JVP/VJP functions, and
1785+
/// apply it to the given differential/pullback.
17851786
///
17861787
/// If `reorderSelf` is true, reorder self so that it appears as:
17871788
/// - The last parameter, for differentials.
17881789
/// - The last result, for pullbacks.
1789-
SILFunction *getOrCreateAutoDiffLinearMapThunk(
1790-
AutoDiffAssociatedFunctionKind assocFnKind,
1790+
ManagedValue getThunkedAutoDiffLinearMap(
1791+
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
17911792
CanSILFunctionType fromType, CanSILFunctionType toType,
17921793
bool reorderSelf);
17931794

lib/SILGen/SILGenPoly.cpp

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3362,7 +3362,6 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType(
33623362
/// tuple. Otherwise, add this value directly to `result`.
33633363
static void extractAllElements(SILValue val, SILBuilder &builder,
33643364
SmallVectorImpl<SILValue> &result) {
3365-
// auto &fn = builder.getFunction();
33663365
if (auto tupleType = val->getType().getAs<TupleType>())
33673366
for (auto i : range(tupleType->getNumElements()))
33683367
result.push_back(builder.createTupleExtract(val.getLoc(), val, i));
@@ -3383,10 +3382,11 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
33833382

33843383
// SWIFT_ENABLE_TENSORFLOW
33853384
/// Adapted from `SILGenModule::getOrCreateReabstractionThunk`.
3386-
SILFunction *
3387-
SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
3388-
AutoDiffAssociatedFunctionKind assocFnKind, CanSILFunctionType fromType,
3389-
CanSILFunctionType toType, bool reorderSelf) {
3385+
ManagedValue
3386+
SILGenFunction::getThunkedAutoDiffLinearMap(
3387+
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
3388+
CanSILFunctionType fromType, CanSILFunctionType toType,
3389+
bool reorderSelf) {
33903390
// Compute the thunk type.
33913391
SubstitutionMap interfaceSubs;
33923392
GenericEnvironment *genericEnv = nullptr;
@@ -3407,7 +3407,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
34073407
std::string name = mangler.mangleReabstractionThunkHelper(
34083408
thunkType, fromInterfaceType, toInterfaceType,
34093409
Type(), getModule().getSwiftModule());
3410-
// TODO: Use principled mangling.
3410+
// TODO(TF-685): Use principled thunk mangling.
34113411
if (reorderSelf) {
34123412
switch (assocFnKind) {
34133413
case AutoDiffAssociatedFunctionKind::JVP:
@@ -3426,8 +3426,20 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
34263426
auto *thunk = fb.getOrCreateSharedFunction(
34273427
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
34283428
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
3429+
3430+
// Partially-apply the thunk to `linearMap` and return the thunked value.
3431+
auto getThunkedResult = [&]() {
3432+
auto thunkedFn = createPartialApplyOfThunk(
3433+
*this, loc, thunk, interfaceSubs, dynamicSelfType, toType, linearMap);
3434+
if (!toType->isNoEscape())
3435+
return thunkedFn;
3436+
// Handle escaping to noescape conversion.
3437+
return B.createConvertEscapeToNoEscape(
3438+
loc, thunkedFn, SILType::getPrimitiveObjectType(toType));
3439+
};
3440+
34293441
if (!thunk->empty())
3430-
return thunk;
3442+
return getThunkedResult();
34313443
thunk->setGenericEnvironment(genericEnv);
34323444
thunk->setOwnershipEliminated();
34333445

@@ -3558,9 +3570,9 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
35583570
arguments.push_back(load);
35593571
}
35603572

3561-
auto linearMap = thunk->getArgumentsWithoutIndirectResults().back();
3573+
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back();
35623574
auto *apply = thunkSGF.B.createApply(
3563-
loc, linearMap, SubstitutionMap(), arguments, /*isNonThrowing*/ false);
3575+
loc, linearMapArg, SubstitutionMap(), arguments, /*isNonThrowing*/ false);
35643576

35653577
// Get return elements.
35663578
SmallVector<SILValue, 4> results;
@@ -3628,7 +3640,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
36283640

36293641
// Create return.
36303642
thunkSGF.B.createReturn(loc, retVal);
3631-
return thunk;
3643+
return getThunkedResult();
36323644
}
36333645

36343646
/// Forward function arguments, converting ownership.
@@ -3685,9 +3697,12 @@ static void forwardFunctionArgumentsConvertingOwnership(
36853697
SILFunction *
36863698
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
36873699
SILFunction *original, SILAutoDiffIndices &indices,
3688-
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind) {
3700+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
3701+
bool reorderSelf) {
36893702
auto assocFnType = assocFn->getLoweredFunctionType();
36903703

3704+
// TODO(TF-685): Use principled thunk mangling.
3705+
// Do not simply reuse reabstraction thunk mangling.
36913706
Mangle::ASTMangler mangler;
36923707
auto name = getASTContext().getIdentifier(
36933708
mangler.mangleAutoDiffAssociatedFunctionHelper(
@@ -3744,8 +3759,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
37443759

37453760
SmallVector<SILValue, 8> directResults;
37463761
extractAllElements(apply, thunkSGF.B, directResults);
3747-
auto linearMap = directResults.back();
3748-
auto linearMapFnType = linearMap->getType().castTo<SILFunctionType>();
3762+
auto linearMap = ManagedValue::forBorrowedObjectRValue(directResults.back());
3763+
auto linearMapFnType = linearMap.getType().castTo<SILFunctionType>();
37493764
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
37503765
origAssocFnType->getResults().back().getSILStorageType())
37513766
.castTo<SILFunctionType>();
@@ -3767,33 +3782,17 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
37673782
thunkSGF.B.createReturn(loc, retValue);
37683783
};
37693784

3770-
// If linear map types are unchanged, return the `apply` instruction.
3771-
if (linearMapFnType == targetLinearMapFnType) {
3785+
// If self ordering is not necessary and linear map types are unchanged,
3786+
// return the `apply` instruction.
3787+
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
37723788
createReturn(apply);
37733789
return thunk;
37743790
}
37753791

3776-
// Generate linear map thunk for reabstraction/self reordering.
3777-
auto shouldReorderSelf = [&]() {
3778-
if (!original->hasSelfParam())
3779-
return false;
3780-
auto selfParamIndex =
3781-
original->getArgumentsWithoutIndirectResults().size() - 1;
3782-
if (!indices.isWrtParameter(selfParamIndex))
3783-
return false;
3784-
return indices.parameters->getNumIndices() > 1;
3785-
};
3786-
bool reorderSelf = shouldReorderSelf();
3787-
auto *linearMapThunk = thunkSGF.getOrCreateAutoDiffLinearMapThunk(
3788-
assocFnKind, linearMapFnType, targetLinearMapFnType, reorderSelf);
3789-
auto linearMapThunkValue =
3790-
thunkSGF.B.createFunctionRefFor(loc, linearMapThunk);
3791-
SubstitutionMap linearMapSubs;
3792-
if (linearMapThunk->getLoweredFunctionType()->isPolymorphic())
3793-
linearMapSubs = thunk->getForwardingSubstitutionMap();
3794-
linearMap = thunkSGF.B.createPartialApply(
3795-
loc, linearMapThunkValue, linearMapSubs, {linearMap},
3796-
linearMapFnType->getCalleeConvention());
3792+
// Otherwise, apply reabstraction/self reordering thunk to linear map.
3793+
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
3794+
linearMap, assocFnKind, linearMapFnType, targetLinearMapFnType,
3795+
reorderSelf);
37973796

37983797
// Return original results and thunked differential/pullback.
37993798
if (directResults.size() > 1) {
@@ -3802,10 +3801,10 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
38023801
auto originalDirectResult =
38033802
joinElements(originalDirectResults, thunkSGF.B, apply.getLoc());
38043803
auto thunkResult = joinElements(
3805-
{originalDirectResult, linearMap}, thunkSGF.B, loc);
3804+
{originalDirectResult, linearMap.getValue()}, thunkSGF.B, loc);
38063805
createReturn(thunkResult);
38073806
} else {
3808-
createReturn(linearMap);
3807+
createReturn(linearMap.getValue());
38093808
}
38103809
return thunk;
38113810
}

lib/SILGen/SILGenThunk.cpp

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -84,28 +84,27 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef assocFnDeclRef,
8484
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
8585
originalLinkage, /*isAssocFnExported*/ true);
8686
auto name = assocFnDeclRef.mangle();
87-
auto *F = builder.getOrCreateFunction(
87+
auto *thunk = builder.getOrCreateFunction(
8888
assocFnDecl, name, linkage, assocFnTy, IsBare, IsTransparent,
8989
assocFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(), IsThunk);
90-
91-
if (F->empty()) {
92-
if (auto genSig = assocFnTy->getGenericSignature())
93-
F->setGenericEnvironment(genSig->createGenericEnvironment());
94-
SILGenFunction SGF(*this, *F, SwiftModule);
95-
SmallVector<ManagedValue, 4> params;
96-
auto loc = assocFnDeclRef.getAsRegularLocation();
97-
SGF.collectThunkParams(loc, params);
98-
auto assocFnRef = SGF.B.createFunctionRef(loc, assocFn);
99-
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(assocFnTy);
100-
SmallVector<SILValue, 4> args(F->getArguments().begin(),
101-
F->getArguments().end());
102-
auto apply = SGF.emitApplyWithRethrow(
103-
loc, assocFnRef, autoDiffAssocFnSILTy,
104-
SGF.getForwardingSubstitutionMap(), args);
105-
SGF.B.createReturn(loc, apply);
106-
}
107-
108-
return F;
90+
if (!thunk->empty())
91+
return thunk;
92+
93+
if (auto genSig = assocFnTy->getGenericSignature())
94+
thunk->setGenericEnvironment(genSig->createGenericEnvironment());
95+
SILGenFunction SGF(*this, *thunk, SwiftModule);
96+
SmallVector<ManagedValue, 4> params;
97+
auto loc = assocFnDeclRef.getAsRegularLocation();
98+
SGF.collectThunkParams(loc, params);
99+
auto assocFnRef = SGF.B.createFunctionRef(loc, assocFn);
100+
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(assocFnTy);
101+
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
102+
thunk->getArguments().end());
103+
auto apply = SGF.emitApplyWithRethrow(
104+
loc, assocFnRef, autoDiffAssocFnSILTy,
105+
SGF.getForwardingSubstitutionMap(), args);
106+
SGF.B.createReturn(loc, apply);
107+
return thunk;
109108
}
110109

111110
// SWIFT_ENABLE_TENSORFLOW
@@ -120,37 +119,38 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
120119
auto originalLinkage = originalFn.getLinkage(ForDefinition);
121120
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
122121
originalLinkage, /*isAssocFnExported*/ true);
123-
auto name = assocFnDeclRef.mangle() + "_thunk";
124-
auto *F = builder.getOrCreateFunction(
122+
// TODO(TF-685): Use principled thunk mangling.
123+
// Do not simply reuse reabstraction thunk mangling.
124+
auto name = assocFnDeclRef.mangle() + "_vtable_entry_thunk";
125+
auto *thunk = builder.getOrCreateFunction(
125126
assocFnDecl, name, linkage, constantTy, IsBare, IsTransparent,
126127
assocFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(), IsThunk);
127-
128-
if (F->empty()) {
129-
if (auto genSig = constantTy->getGenericSignature())
130-
F->setGenericEnvironment(genSig->createGenericEnvironment());
131-
SILGenFunction SGF(*this, *F, SwiftModule);
132-
SmallVector<ManagedValue, 4> params;
133-
auto loc = assocFnDeclRef.getAsRegularLocation();
134-
SGF.collectThunkParams(loc, params);
135-
auto originalFnRef = SGF.emitGlobalFunctionRef(loc, originalFn);
136-
auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered(
137-
SGF.getASTContext(),
138-
assocFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
139-
auto autoDiffFn = SGF.B.createAutoDiffFunction(
140-
loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef);
141-
auto autoDiffAssocFn = SGF.B.createAutoDiffFunctionExtract(
142-
loc, AutoDiffFunctionExtractInst::Extractee(autoDiffFuncId->getKind()),
143-
/*differentiationOrder*/ 1, autoDiffFn);
144-
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy);
145-
SmallVector<SILValue, 4> args(F->getArguments().begin(),
146-
F->getArguments().end());
147-
auto apply = SGF.emitApplyWithRethrow(
148-
loc, autoDiffAssocFn, autoDiffAssocFnSILTy,
149-
SGF.getForwardingSubstitutionMap(), args);
150-
SGF.B.createReturn(loc, apply);
151-
}
152-
153-
return F;
128+
if (!thunk->empty())
129+
return thunk;
130+
131+
if (auto genSig = constantTy->getGenericSignature())
132+
thunk->setGenericEnvironment(genSig->createGenericEnvironment());
133+
SILGenFunction SGF(*this, *thunk, SwiftModule);
134+
SmallVector<ManagedValue, 4> params;
135+
auto loc = assocFnDeclRef.getAsRegularLocation();
136+
SGF.collectThunkParams(loc, params);
137+
auto originalFnRef = SGF.emitGlobalFunctionRef(loc, originalFn);
138+
auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered(
139+
SGF.getASTContext(),
140+
assocFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
141+
auto autoDiffFn = SGF.B.createAutoDiffFunction(
142+
loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef);
143+
auto autoDiffAssocFn = SGF.B.createAutoDiffFunctionExtract(
144+
loc, AutoDiffFunctionExtractInst::Extractee(autoDiffFuncId->getKind()),
145+
/*differentiationOrder*/ 1, autoDiffFn);
146+
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy);
147+
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
148+
thunk->getArguments().end());
149+
auto apply = SGF.emitApplyWithRethrow(
150+
loc, autoDiffAssocFn, autoDiffAssocFnSILTy,
151+
SGF.getForwardingSubstitutionMap(), args);
152+
SGF.B.createReturn(loc, apply);
153+
return thunk;
154154
}
155155

156156
ManagedValue

test/AutoDiff/method_self_reordering_thunk/main.swift

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)