Skip to content

Commit e13e7e4

Browse files
committed
[AutoDiff] Fix SILGen JVP/VJP thunking bug.
SILGen JVP/VJP thunks for methods now always perform self reordering. Add leak checking test. Previously, self reordering was not performed if actual JVP/VJP type matched expected JVP/VJP type, which is incorrect. Gardening included: - Change `SILFunction *SILGenFunction::getOrCreateAutoDiffLinearMapThunk` to `ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap`. The new helper is more ergonomic and directly returns thunked values. - Use ad-hoc `_vtable_entry_thunk` suffix for JVP/VJP vtable entry thunks. - Add todo comments for TF-685: principled AD thunk mangling. Resolves TF-698.
1 parent 79cdb96 commit e13e7e4

File tree

10 files changed

+209
-137
lines changed

10 files changed

+209
-137
lines changed

lib/SIL/SILFunctionType.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
157157
CanGenericSignature assocFnGenSig,
158158
ArrayRef<SILParameterInfo> originalParameters,
159159
AutoDiffIndexSubset *parameterIndices, SILModule &module) {
160-
// If associated function has no
161160
if (!assocFnGenSig)
162161
return nullptr;
163162
auto &ctx = module.getASTContext();

lib/SILGen/SILGen.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -810,13 +810,26 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
810810
indices.parameters, indices.source, /*differentiationOrder*/ 1,
811811
AutoDiffAssociatedFunctionKind::VJP, M, lookUpConformance);
812812

813+
// Self reordering is necessary if wrt at least two parameters, including
814+
// self.
815+
auto shouldReorderSelf = [&]() {
816+
if (!F->hasSelfParam())
817+
return false;
818+
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
819+
if (!indices.isWrtParameter(selfParamIndex))
820+
return false;
821+
return indices.parameters->getNumIndices() > 1;
822+
};
823+
bool reorderSelf = shouldReorderSelf();
824+
813825
// Thunk JVP method, if it is defined.
814826
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
815827
SILFunction *jvpThunk;
816828
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
817-
if (jvpFn->getLoweredFunctionType() != expectedJVPType) {
829+
if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) {
818830
jvpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
819-
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP);
831+
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP,
832+
reorderSelf);
820833
} else {
821834
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
822835
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
@@ -831,9 +844,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
831844
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
832845
SILFunction *vjpThunk;
833846
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
834-
if (vjpFn->getLoweredFunctionType() != expectedVJPType) {
847+
if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) {
835848
vjpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
836-
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP);
849+
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP,
850+
reorderSelf);
837851
} else {
838852
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
839853
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,

lib/SILGen/SILGen.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
187187
/// - The last result in the returned pullback.
188188
SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk(
189189
SILFunction *original, SILAutoDiffIndices &indices,
190-
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind);
190+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
191+
bool reorderSelf);
191192

192193
/// Determine whether the given class has any instance variables that
193194
/// need to be destroyed.

lib/SILGen/SILGenFunction.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,14 +1780,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
17801780
// Differentiation thunks
17811781
//===--------------------------------------------------------------------===//
17821782

1783-
/// Get or create a thunk for reabstracting differentials/pullbacks returned
1784-
/// by user-defined JVP/VJP functions.
1783+
/// Get or create a thunk for reabstracting and self-reordering
1784+
/// differentials/pullbacks returned by user-defined JVP/VJP functions, and
1785+
/// apply it to the given differential/pullback.
17851786
///
17861787
/// If `reorderSelf` is true, reorder self so that it appears as:
17871788
/// - The last parameter, for differentials.
17881789
/// - The last result, for pullbacks.
1789-
SILFunction *getOrCreateAutoDiffLinearMapThunk(
1790-
AutoDiffAssociatedFunctionKind assocFnKind,
1790+
ManagedValue getThunkedAutoDiffLinearMap(
1791+
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
17911792
CanSILFunctionType fromType, CanSILFunctionType toType,
17921793
bool reorderSelf);
17931794

lib/SILGen/SILGenPoly.cpp

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3364,7 +3364,6 @@ static CanSILFunctionType buildWithoutActuallyEscapingThunkType(
33643364
/// tuple. Otherwise, add this value directly to `result`.
33653365
static void extractAllElements(SILValue val, SILBuilder &builder,
33663366
SmallVectorImpl<SILValue> &result) {
3367-
// auto &fn = builder.getFunction();
33683367
if (auto tupleType = val->getType().getAs<TupleType>())
33693368
for (auto i : range(tupleType->getNumElements()))
33703369
result.push_back(builder.createTupleExtract(val.getLoc(), val, i));
@@ -3385,10 +3384,11 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
33853384

33863385
// SWIFT_ENABLE_TENSORFLOW
33873386
/// Adapted from `SILGenModule::getOrCreateReabstractionThunk`.
3388-
SILFunction *
3389-
SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
3390-
AutoDiffAssociatedFunctionKind assocFnKind, CanSILFunctionType fromType,
3391-
CanSILFunctionType toType, bool reorderSelf) {
3387+
ManagedValue
3388+
SILGenFunction::getThunkedAutoDiffLinearMap(
3389+
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
3390+
CanSILFunctionType fromType, CanSILFunctionType toType,
3391+
bool reorderSelf) {
33923392
// Compute the thunk type.
33933393
SubstitutionMap interfaceSubs;
33943394
GenericEnvironment *genericEnv = nullptr;
@@ -3409,7 +3409,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
34093409
std::string name = mangler.mangleReabstractionThunkHelper(
34103410
thunkType, fromInterfaceType, toInterfaceType,
34113411
Type(), getModule().getSwiftModule());
3412-
// TODO: Use principled mangling.
3412+
// TODO(TF-685): Use principled thunk mangling.
34133413
if (reorderSelf) {
34143414
switch (assocFnKind) {
34153415
case AutoDiffAssociatedFunctionKind::JVP:
@@ -3428,8 +3428,20 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
34283428
auto *thunk = fb.getOrCreateSharedFunction(
34293429
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
34303430
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
3431+
3432+
// Partially-apply the thunk to `linearMap` and return the thunked value.
3433+
auto getThunkedResult = [&]() {
3434+
auto thunkedFn = createPartialApplyOfThunk(
3435+
*this, loc, thunk, interfaceSubs, dynamicSelfType, toType, linearMap);
3436+
if (!toType->isNoEscape())
3437+
return thunkedFn;
3438+
// Handle escaping to noescape conversion.
3439+
return B.createConvertEscapeToNoEscape(
3440+
loc, thunkedFn, SILType::getPrimitiveObjectType(toType));
3441+
};
3442+
34313443
if (!thunk->empty())
3432-
return thunk;
3444+
return getThunkedResult();
34333445
thunk->setGenericEnvironment(genericEnv);
34343446
thunk->setOwnershipEliminated();
34353447

@@ -3560,9 +3572,9 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
35603572
arguments.push_back(load);
35613573
}
35623574

3563-
auto linearMap = thunk->getArgumentsWithoutIndirectResults().back();
3575+
auto *linearMapArg = thunk->getArgumentsWithoutIndirectResults().back();
35643576
auto *apply = thunkSGF.B.createApply(
3565-
loc, linearMap, SubstitutionMap(), arguments, /*isNonThrowing*/ false);
3577+
loc, linearMapArg, SubstitutionMap(), arguments, /*isNonThrowing*/ false);
35663578

35673579
// Get return elements.
35683580
SmallVector<SILValue, 4> results;
@@ -3630,7 +3642,7 @@ SILGenFunction::getOrCreateAutoDiffLinearMapThunk(
36303642

36313643
// Create return.
36323644
thunkSGF.B.createReturn(loc, retVal);
3633-
return thunk;
3645+
return getThunkedResult();
36343646
}
36353647

36363648
/// Forward function arguments, converting ownership.
@@ -3687,9 +3699,12 @@ static void forwardFunctionArgumentsConvertingOwnership(
36873699
SILFunction *
36883700
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
36893701
SILFunction *original, SILAutoDiffIndices &indices,
3690-
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind) {
3702+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
3703+
bool reorderSelf) {
36913704
auto assocFnType = assocFn->getLoweredFunctionType();
36923705

3706+
// TODO(TF-685): Use principled thunk mangling.
3707+
// Do not simply reuse reabstraction thunk mangling.
36933708
Mangle::ASTMangler mangler;
36943709
auto name = getASTContext().getIdentifier(
36953710
mangler.mangleAutoDiffAssociatedFunctionHelper(
@@ -3746,8 +3761,8 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
37463761

37473762
SmallVector<SILValue, 8> directResults;
37483763
extractAllElements(apply, thunkSGF.B, directResults);
3749-
auto linearMap = directResults.back();
3750-
auto linearMapFnType = linearMap->getType().castTo<SILFunctionType>();
3764+
auto linearMap = ManagedValue::forBorrowedObjectRValue(directResults.back());
3765+
auto linearMapFnType = linearMap.getType().castTo<SILFunctionType>();
37513766
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
37523767
origAssocFnType->getResults().back().getSILStorageType())
37533768
.castTo<SILFunctionType>();
@@ -3769,33 +3784,17 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
37693784
thunkSGF.B.createReturn(loc, retValue);
37703785
};
37713786

3772-
// If linear map types are unchanged, return the `apply` instruction.
3773-
if (linearMapFnType == targetLinearMapFnType) {
3787+
// If self ordering is not necessary and linear map types are unchanged,
3788+
// return the `apply` instruction.
3789+
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
37743790
createReturn(apply);
37753791
return thunk;
37763792
}
37773793

3778-
// Generate linear map thunk for reabstraction/self reordering.
3779-
auto shouldReorderSelf = [&]() {
3780-
if (!original->hasSelfParam())
3781-
return false;
3782-
auto selfParamIndex =
3783-
original->getArgumentsWithoutIndirectResults().size() - 1;
3784-
if (!indices.isWrtParameter(selfParamIndex))
3785-
return false;
3786-
return indices.parameters->getNumIndices() > 1;
3787-
};
3788-
bool reorderSelf = shouldReorderSelf();
3789-
auto *linearMapThunk = thunkSGF.getOrCreateAutoDiffLinearMapThunk(
3790-
assocFnKind, linearMapFnType, targetLinearMapFnType, reorderSelf);
3791-
auto linearMapThunkValue =
3792-
thunkSGF.B.createFunctionRefFor(loc, linearMapThunk);
3793-
SubstitutionMap linearMapSubs;
3794-
if (linearMapThunk->getLoweredFunctionType()->isPolymorphic())
3795-
linearMapSubs = thunk->getForwardingSubstitutionMap();
3796-
linearMap = thunkSGF.B.createPartialApply(
3797-
loc, linearMapThunkValue, linearMapSubs, {linearMap},
3798-
linearMapFnType->getCalleeConvention());
3794+
// Otherwise, apply reabstraction/self reordering thunk to linear map.
3795+
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
3796+
linearMap, assocFnKind, linearMapFnType, targetLinearMapFnType,
3797+
reorderSelf);
37993798

38003799
// Return original results and thunked differential/pullback.
38013800
if (directResults.size() > 1) {
@@ -3804,10 +3803,10 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
38043803
auto originalDirectResult =
38053804
joinElements(originalDirectResults, thunkSGF.B, apply.getLoc());
38063805
auto thunkResult = joinElements(
3807-
{originalDirectResult, linearMap}, thunkSGF.B, loc);
3806+
{originalDirectResult, linearMap.getValue()}, thunkSGF.B, loc);
38083807
createReturn(thunkResult);
38093808
} else {
3810-
createReturn(linearMap);
3809+
createReturn(linearMap.getValue());
38113810
}
38123811
return thunk;
38133812
}

lib/SILGen/SILGenThunk.cpp

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -84,28 +84,27 @@ SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef assocFnDeclRef,
8484
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
8585
originalLinkage, /*isAssocFnExported*/ true);
8686
auto name = assocFnDeclRef.mangle();
87-
auto *F = builder.getOrCreateFunction(
87+
auto *thunk = builder.getOrCreateFunction(
8888
assocFnDecl, name, linkage, assocFnTy, IsBare, IsTransparent,
8989
assocFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(), IsThunk);
90-
91-
if (F->empty()) {
92-
if (auto genSig = assocFnTy->getGenericSignature())
93-
F->setGenericEnvironment(genSig->createGenericEnvironment());
94-
SILGenFunction SGF(*this, *F, SwiftModule);
95-
SmallVector<ManagedValue, 4> params;
96-
auto loc = assocFnDeclRef.getAsRegularLocation();
97-
SGF.collectThunkParams(loc, params);
98-
auto assocFnRef = SGF.B.createFunctionRef(loc, assocFn);
99-
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(assocFnTy);
100-
SmallVector<SILValue, 4> args(F->getArguments().begin(),
101-
F->getArguments().end());
102-
auto apply = SGF.emitApplyWithRethrow(
103-
loc, assocFnRef, autoDiffAssocFnSILTy,
104-
SGF.getForwardingSubstitutionMap(), args);
105-
SGF.B.createReturn(loc, apply);
106-
}
107-
108-
return F;
90+
if (!thunk->empty())
91+
return thunk;
92+
93+
if (auto genSig = assocFnTy->getGenericSignature())
94+
thunk->setGenericEnvironment(genSig->createGenericEnvironment());
95+
SILGenFunction SGF(*this, *thunk, SwiftModule);
96+
SmallVector<ManagedValue, 4> params;
97+
auto loc = assocFnDeclRef.getAsRegularLocation();
98+
SGF.collectThunkParams(loc, params);
99+
auto assocFnRef = SGF.B.createFunctionRef(loc, assocFn);
100+
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(assocFnTy);
101+
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
102+
thunk->getArguments().end());
103+
auto apply = SGF.emitApplyWithRethrow(
104+
loc, assocFnRef, autoDiffAssocFnSILTy,
105+
SGF.getForwardingSubstitutionMap(), args);
106+
SGF.B.createReturn(loc, apply);
107+
return thunk;
109108
}
110109

111110
// SWIFT_ENABLE_TENSORFLOW
@@ -120,37 +119,38 @@ SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
120119
auto originalLinkage = originalFn.getLinkage(ForDefinition);
121120
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
122121
originalLinkage, /*isAssocFnExported*/ true);
123-
auto name = assocFnDeclRef.mangle() + "_thunk";
124-
auto *F = builder.getOrCreateFunction(
122+
// TODO(TF-685): Use principled thunk mangling.
123+
// Do not simply reuse reabstraction thunk mangling.
124+
auto name = assocFnDeclRef.mangle() + "_vtable_entry_thunk";
125+
auto *thunk = builder.getOrCreateFunction(
125126
assocFnDecl, name, linkage, constantTy, IsBare, IsTransparent,
126127
assocFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(), IsThunk);
127-
128-
if (F->empty()) {
129-
if (auto genSig = constantTy->getGenericSignature())
130-
F->setGenericEnvironment(genSig->createGenericEnvironment());
131-
SILGenFunction SGF(*this, *F, SwiftModule);
132-
SmallVector<ManagedValue, 4> params;
133-
auto loc = assocFnDeclRef.getAsRegularLocation();
134-
SGF.collectThunkParams(loc, params);
135-
auto originalFnRef = SGF.emitGlobalFunctionRef(loc, originalFn);
136-
auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered(
137-
SGF.getASTContext(),
138-
assocFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
139-
auto autoDiffFn = SGF.B.createAutoDiffFunction(
140-
loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef);
141-
auto autoDiffAssocFn = SGF.B.createAutoDiffFunctionExtract(
142-
loc, AutoDiffFunctionExtractInst::Extractee(autoDiffFuncId->getKind()),
143-
/*differentiationOrder*/ 1, autoDiffFn);
144-
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy);
145-
SmallVector<SILValue, 4> args(F->getArguments().begin(),
146-
F->getArguments().end());
147-
auto apply = SGF.emitApplyWithRethrow(
148-
loc, autoDiffAssocFn, autoDiffAssocFnSILTy,
149-
SGF.getForwardingSubstitutionMap(), args);
150-
SGF.B.createReturn(loc, apply);
151-
}
152-
153-
return F;
128+
if (!thunk->empty())
129+
return thunk;
130+
131+
if (auto genSig = constantTy->getGenericSignature())
132+
thunk->setGenericEnvironment(genSig->createGenericEnvironment());
133+
SILGenFunction SGF(*this, *thunk, SwiftModule);
134+
SmallVector<ManagedValue, 4> params;
135+
auto loc = assocFnDeclRef.getAsRegularLocation();
136+
SGF.collectThunkParams(loc, params);
137+
auto originalFnRef = SGF.emitGlobalFunctionRef(loc, originalFn);
138+
auto loweredIndices = autoDiffFuncId->getParameterIndices()->getLowered(
139+
SGF.getASTContext(),
140+
assocFnDecl->getInterfaceType()->castTo<AnyFunctionType>());
141+
auto autoDiffFn = SGF.B.createAutoDiffFunction(
142+
loc, loweredIndices, /*differentiationOrder*/ 1, originalFnRef);
143+
auto autoDiffAssocFn = SGF.B.createAutoDiffFunctionExtract(
144+
loc, AutoDiffFunctionExtractInst::Extractee(autoDiffFuncId->getKind()),
145+
/*differentiationOrder*/ 1, autoDiffFn);
146+
auto autoDiffAssocFnSILTy = SILType::getPrimitiveObjectType(constantTy);
147+
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
148+
thunk->getArguments().end());
149+
auto apply = SGF.emitApplyWithRethrow(
150+
loc, autoDiffAssocFn, autoDiffAssocFnSILTy,
151+
SGF.getForwardingSubstitutionMap(), args);
152+
SGF.B.createReturn(loc, apply);
153+
return thunk;
154154
}
155155

156156
ManagedValue

test/AutoDiff/method_self_reordering_thunk/main.swift

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)