Skip to content

Commit 8949d2a

Browse files
author
Marc Rasi
committed
[AutoDiff] simplify SILGen thunking and set correct thunk linkage
1 parent 8946965 commit 8949d2a

File tree

5 files changed

+100
-129
lines changed

5 files changed

+100
-129
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,9 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
770770
SILFunction *jvp = nullptr;
771771
SILFunction *vjp = nullptr;
772772
if (auto *jvpDecl = diffAttr->getJVPFunction())
773-
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
773+
jvp = getFunction(SILDeclRef(jvpDecl), ForDefinition);
774774
if (auto *vjpDecl = diffAttr->getVJPFunction())
775-
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
775+
vjp = getFunction(SILDeclRef(vjpDecl), ForDefinition);
776776
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
777777
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
778778
diffAttr->getDerivativeGenericSignature());
@@ -803,20 +803,8 @@ void SILGenModule::emitDifferentiabilityWitness(
803803
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
804804
loweredParamIndices = loweredParamIndices->extendingCapacity(
805805
getASTContext(), origSilFnType->getNumParameters());
806-
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
807-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
808-
809-
// Self reordering thunk is necessary if wrt at least two parameters,
810-
// including self.
811-
auto shouldReorderSelf = [&]() {
812-
if (!originalFunction->hasSelfParam())
813-
return false;
814-
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
815-
if (!indices.isWrtParameter(selfParamIndex))
816-
return false;
817-
return indices.parameters->getNumIndices() > 1;
818-
};
819-
bool reorderSelf = shouldReorderSelf();
806+
AutoDiffConfig loweredConfig = config;
807+
loweredConfig.parameterIndices = loweredParamIndices;
820808

821809
// Create new SIL differentiability witness.
822810
// Witness JVP and VJP are set below.
@@ -830,25 +818,8 @@ void SILGenModule::emitDifferentiabilityWitness(
830818
// Set derivative function in differentiability witness.
831819
auto setDerivativeInDifferentiabilityWitness =
832820
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
833-
auto expectedDerivativeType =
834-
origSilFnType->getAutoDiffDerivativeFunctionType(
835-
indices.parameters, indices.source, kind, Types,
836-
LookUpConformanceInModule(M.getSwiftModule()));
837-
// Thunk derivative function.
838-
SILFunction *derivativeThunk;
839-
if (reorderSelf ||
840-
derivative->getLoweredFunctionType() != expectedDerivativeType) {
841-
derivativeThunk = getOrCreateAutoDiffDerivativeReabstractionThunk(
842-
originalFunction, indices, derivative, kind, reorderSelf);
843-
} else {
844-
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
845-
// the AST-level parameter indices, not the SIL-level ones.
846-
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
847-
kind, config.parameterIndices, getASTContext());
848-
derivativeThunk = getOrCreateAutoDiffDerivativeForwardingThunk(
849-
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
850-
expectedDerivativeType);
851-
}
821+
auto derivativeThunk = getOrCreateCustomDerivativeThunk(
822+
derivative, originalFunction, loweredConfig, kind);
852823
// Check for existing same derivative.
853824
// TODO(TF-835): Remove condition below and simplify assertion to
854825
// `!diffWitness->getDerivative(kind)` after `@derivative` attribute

lib/SILGen/SILGen.h

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,6 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
147147
SILFunction *getDynamicThunk(SILDeclRef constant,
148148
CanSILFunctionType constantTy);
149149

150-
// SWIFT_ENABLE_TENSORFLOW
151-
/// Get or create an autodiff derivative function forwarding thunk for the
152-
/// given derivative SILDeclRef, SILFunction, and function type.
153-
/// The thunk simply forwards arguments and returns results: use this when no
154-
/// reabstraction or self reordering is necessary.
155-
SILFunction *getOrCreateAutoDiffDerivativeForwardingThunk(
156-
SILDeclRef derivativeFnRef, SILFunction *derivativeFn,
157-
CanSILFunctionType derivativeFnTy);
158-
159150
// SWIFT_ENABLE_TENSORFLOW
160151
/// Get or create an autodiff derivative function vtable entry thunk for the
161152
/// given SILDeclRef and derivative function type.
@@ -182,8 +173,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
182173
CanType dynamicSelfType);
183174

184175
// SWIFT_ENABLE_TENSORFLOW
185-
/// Get or create an autodiff derivative function thunk performing
186-
/// reabstraction and/or self-reordering.
176+
/// Given a user-specified custom derivative, get or create a thunk that calls
177+
/// the custom derivative, and that haswith the abstraction pattern and
178+
/// parameter ordering required for the SIL derivative of the given original
179+
/// function.
180+
///
181+
/// To achieve the required SIL derivative, the thunk may perform any subset
182+
/// of:
183+
/// - Self-reordering.
184+
/// - Reabstraction.
187185
///
188186
/// Self-reordering is done for canonicalizing the types of derivative
189187
/// functions for instance methods wrt self. We want users to define
@@ -223,13 +221,14 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
223221
/// ordering uniform for "wrt self instance method derivatives" and simplifies
224222
/// the transform rules.
225223
///
226-
/// If `reorderSelf` is true, reorder self so that it appears as:
224+
/// If self must be reordered, reorder it so that it appears as:
227225
/// - The last parameter in the returned differential.
228226
/// - The last result in the returned pullback.
229-
SILFunction *getOrCreateAutoDiffDerivativeReabstractionThunk(
230-
SILFunction *original, SILAutoDiffIndices &indices,
231-
SILFunction *derivativeFn,
232-
AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf);
227+
SILFunction *
228+
getOrCreateCustomDerivativeThunk(
229+
SILFunction *customDerivativeFn,
230+
SILFunction *originalFn, const AutoDiffConfig &config,
231+
AutoDiffDerivativeFunctionKind kind);
233232

234233
/// Determine whether the given class has any instance variables that
235234
/// need to be destroyed.

lib/SILGen/SILGenPoly.cpp

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3691,64 +3691,65 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
36913691

36923692
// SWIFT_ENABLE_TENSORFLOW
36933693
SILFunction *
3694-
SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
3695-
SILFunction *original, SILAutoDiffIndices &indices,
3696-
SILFunction *derivativeFn, AutoDiffDerivativeFunctionKind derivativeFnKind,
3697-
bool reorderSelf) {
3698-
auto derivativeFnType = derivativeFn->getLoweredFunctionType();
3694+
SILGenModule::getOrCreateCustomDerivativeThunk(
3695+
SILFunction *customDerivativeFn,
3696+
SILFunction *originalFn, const AutoDiffConfig &config,
3697+
AutoDiffDerivativeFunctionKind kind) {
3698+
auto indices = config.getSILAutoDiffIndices();
3699+
auto customDerivativeFnTy = customDerivativeFn->getLoweredFunctionType();
3700+
3701+
Lowering::GenericContextScope genericContextScope(
3702+
Types, customDerivativeFnTy->getSubstGenericSignature());
3703+
auto *thunkGenericEnv = customDerivativeFnTy->getSubstGenericSignature()
3704+
? customDerivativeFnTy->getSubstGenericSignature()->getGenericEnvironment()
3705+
: nullptr;
3706+
3707+
auto origFnTy = originalFn->getLoweredFunctionType();
3708+
auto thunkFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
3709+
indices.parameters, indices.source,
3710+
kind, Types, LookUpConformanceInModule(M.getSwiftModule()),
3711+
customDerivativeFnTy->getSubstGenericSignature());
3712+
assert(!thunkFnTy->getExtInfo().hasContext());
36993713

37003714
// TODO(TF-685): Use principled thunk mangling.
37013715
// Do not simply reuse reabstraction thunk mangling.
37023716
Mangle::ASTMangler mangler;
37033717
auto name = getASTContext().getIdentifier(
37043718
mangler.mangleAutoDiffDerivativeFunctionHelper(
3705-
original->getName(), derivativeFnKind, indices)).str();
3719+
originalFn->getName(), kind, indices)).str();
37063720

3707-
Lowering::GenericContextScope genericContextScope(
3708-
Types, derivativeFnType->getSubstGenericSignature());
3709-
auto *thunkGenericEnv = derivativeFnType->getSubstGenericSignature()
3710-
? derivativeFnType->getSubstGenericSignature()->getGenericEnvironment()
3711-
: nullptr;
3712-
3713-
auto origFnType = original->getLoweredFunctionType();
3714-
auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType(
3715-
indices.parameters, indices.source,
3716-
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
3717-
derivativeFnType->getSubstGenericSignature());
3718-
assert(!origDerivativeFnType->getExtInfo().hasContext());
3719-
3720-
auto loc = derivativeFn->getLocation();
3721+
auto loc = customDerivativeFn->getLocation();
37213722
SILGenFunctionBuilder fb(*this);
37223723
// This thunk is publicly exposed and cannot be transparent.
37233724
// Instead, mark it as "always inline" for optimization.
37243725
auto *thunk = fb.getOrCreateFunction(
3725-
loc, name, original->getLinkage(), origDerivativeFnType, IsBare,
3726-
IsNotTransparent, derivativeFn->isSerialized(),
3727-
derivativeFn->isDynamicallyReplaceable(), derivativeFn->getEntryCount(),
3728-
derivativeFn->isThunk(), derivativeFn->getClassSubclassScope());
3726+
loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare,
3727+
IsNotTransparent, customDerivativeFn->isSerialized(),
3728+
customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(),
3729+
IsThunk, customDerivativeFn->getClassSubclassScope());
37293730
thunk->setInlineStrategy(AlwaysInline);
37303731
if (!thunk->empty())
37313732
return thunk;
37323733
thunk->setGenericEnvironment(thunkGenericEnv);
37333734

3734-
SILGenFunction thunkSGF(*this, *thunk, derivativeFn->getDeclContext());
3735+
SILGenFunction thunkSGF(*this, *thunk, customDerivativeFn->getDeclContext());
37353736
SmallVector<ManagedValue, 4> params;
37363737
SmallVector<SILArgument *, 4> indirectResults;
37373738
thunkSGF.collectThunkParams(loc, params, &indirectResults);
37383739

3739-
auto *derivativeFnRef = thunkSGF.B.createFunctionRef(loc, derivativeFn);
3740-
auto derivativeFnRefType =
3741-
derivativeFnRef->getType().castTo<SILFunctionType>();
3740+
auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
3741+
auto fnRefType =
3742+
fnRef->getType().castTo<SILFunctionType>();
37423743

37433744
// Collect thunk arguments, converting ownership.
37443745
SmallVector<SILValue, 8> arguments;
37453746
for (auto *indRes : indirectResults)
37463747
arguments.push_back(indRes);
3747-
forwardFunctionArguments(thunkSGF, loc, derivativeFnRefType, params,
3748+
forwardFunctionArguments(thunkSGF, loc, fnRefType, params,
37483749
arguments);
37493750
// Apply function argument.
37503751
auto apply = thunkSGF.emitApplyWithRethrow(
3751-
loc, derivativeFnRef, /*substFnType*/ derivativeFnRef->getType(),
3752+
loc, fnRef, /*substFnType*/ fnRef->getType(),
37523753
thunk->getForwardingSubstitutionMap(), arguments);
37533754

37543755
// Create return instruction in the thunk, first deallocating local
@@ -3761,15 +3762,27 @@ SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
37613762
thunkSGF.B.createReturn(loc, retValue);
37623763
};
37633764

3765+
// Self reordering thunk is necessary if wrt at least two parameters,
3766+
// including self.
3767+
auto shouldReorderSelf = [&]() {
3768+
if (!originalFn->hasSelfParam())
3769+
return false;
3770+
auto selfParamIndex = origFnTy->getNumParameters() - 1;
3771+
if (!indices.isWrtParameter(selfParamIndex))
3772+
return false;
3773+
return indices.parameters->getNumIndices() > 1;
3774+
};
3775+
bool reorderSelf = shouldReorderSelf();
3776+
37643777
// If self ordering is not necessary and linear map types are unchanged,
37653778
// return the `apply` instruction.
37663779
auto linearMapFnType = cast<SILFunctionType>(
37673780
thunk
37683781
->mapTypeIntoContext(
3769-
derivativeFnRefType->getResults().back().getInterfaceType())
3782+
fnRefType->getResults().back().getInterfaceType())
37703783
->getCanonicalType());
37713784
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
3772-
origDerivativeFnType->getResults().back().getSILStorageInterfaceType())
3785+
thunkFnTy->getResults().back().getSILStorageInterfaceType())
37733786
.castTo<SILFunctionType>();
37743787
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
37753788
createReturn(apply);
@@ -3781,7 +3794,7 @@ SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
37813794
extractAllElements(apply, loc, thunkSGF.B, directResults);
37823795
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
37833796
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
3784-
auto linearMapKind = derivativeFnKind.getLinearMapKind();
3797+
auto linearMapKind = kind.getLinearMapKind();
37853798
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
37863799
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
37873800
reorderSelf);

lib/SILGen/SILGenThunk.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -70,47 +70,6 @@ SILFunction *SILGenModule::getDynamicThunk(SILDeclRef constant,
7070
return F;
7171
}
7272

73-
// SWIFT_ENABLE_TENSORFLOW
74-
SILFunction *
75-
SILGenModule::getOrCreateAutoDiffDerivativeForwardingThunk(
76-
SILDeclRef derivativeFnDeclRef, SILFunction *derivativeFn,
77-
CanSILFunctionType derivativeFnTy) {
78-
auto *autoDiffFuncId =
79-
derivativeFnDeclRef.autoDiffDerivativeFunctionIdentifier;
80-
assert(autoDiffFuncId);
81-
auto *derivativeFnDecl = derivativeFnDeclRef.getDecl();
82-
83-
SILGenFunctionBuilder builder(*this);
84-
auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction();
85-
auto name = derivativeFnDeclRef.mangle();
86-
// This thunk is publicly exposed and cannot be transparent.
87-
// Instead, mark it as "always inline" for optimization.
88-
auto *thunk = builder.getOrCreateFunction(
89-
derivativeFnDecl, name, originalFn.getLinkage(ForDefinition),
90-
derivativeFnTy, IsBare, IsNotTransparent,
91-
derivativeFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(),
92-
IsThunk);
93-
thunk->setInlineStrategy(AlwaysInline);
94-
if (!thunk->empty())
95-
return thunk;
96-
97-
if (auto genSig = derivativeFnTy->getSubstGenericSignature())
98-
thunk->setGenericEnvironment(genSig->getGenericEnvironment());
99-
SILGenFunction SGF(*this, *thunk, SwiftModule);
100-
SmallVector<ManagedValue, 4> params;
101-
auto loc = derivativeFnDeclRef.getAsRegularLocation();
102-
SGF.collectThunkParams(loc, params);
103-
auto derivativeFnRef = SGF.B.createFunctionRef(loc, derivativeFn);
104-
auto autoDiffDerivativeFnSILTy = SILType::getPrimitiveObjectType(derivativeFnTy);
105-
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
106-
thunk->getArguments().end());
107-
auto apply = SGF.emitApplyWithRethrow(
108-
loc, derivativeFnRef, autoDiffDerivativeFnSILTy,
109-
SGF.getForwardingSubstitutionMap(), args);
110-
SGF.B.createReturn(loc, apply);
111-
return thunk;
112-
}
113-
11473
// SWIFT_ENABLE_TENSORFLOW
11574
SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
11675
SILDeclRef derivativeFnDeclRef, CanSILFunctionType constantTy) {

0 commit comments

Comments
 (0)