Skip to content

Commit 64dd82a

Browse files
author
marcrasi
authored
[AutoDiff] simplify SILGen thunking and set correct thunk linkage (#28616)
Sets the SILGen'd AD thunk linkage to the same linkage as the custom derivative that they wrap, by doing two things: * Use `customDerivativeFn->getLinkage()` for the linkage of the AD thunk. * Use `ForDefinition` when getting the custom derivative functions. The `NotForDefinition` flag is for creating function declarations with external linkage, so if we use that then `customDerivativeFn->getLinkage()` is an external linkage, which is not right. (This code is probably going to be deleted very soon anyways, when we forbid custom derivatives in `@differentiable` attributes.) Also it was initially hard for me to figure out what was going on, so I made some simplifications: * Delete `getOrCreateAutoDiffDerivativeForwardingThunk` because we can just use the other thunking function, which has a superset of its functionality. * Rename `getOrCreateAutoDiffDerivativeReabstractionThunk` -> `getOrCreateCustomDerivativeThunk` to clarify that it does all thunking necessary for custom derivatives, not just reabstraction. * Make `getOrCreateCustomDerivativeThunk` determine whether self needs reordering, so that the caller does not have to worry about that.
1 parent 02d8d83 commit 64dd82a

File tree

5 files changed

+98
-133
lines changed

5 files changed

+98
-133
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,9 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
769769
SILFunction *jvp = nullptr;
770770
SILFunction *vjp = nullptr;
771771
if (auto *jvpDecl = diffAttr->getJVPFunction())
772-
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
772+
jvp = getFunction(SILDeclRef(jvpDecl), ForDefinition);
773773
if (auto *vjpDecl = diffAttr->getVJPFunction())
774-
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
774+
vjp = getFunction(SILDeclRef(vjpDecl), ForDefinition);
775775
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
776776
assert((!F->getLoweredFunctionType()->getSubstGenericSignature() ||
777777
diffAttr->getDerivativeGenericSignature()) &&
@@ -829,20 +829,6 @@ void SILGenModule::emitDifferentiabilityWitness(
829829
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
830830
silParamIndices = silParamIndices->extendingCapacity(
831831
getASTContext(), origSilFnType->getNumParameters());
832-
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
833-
SILAutoDiffIndices indices(/*source*/ 0, silParamIndices);
834-
835-
// Self reordering thunk is necessary if wrt at least two parameters,
836-
// including self.
837-
auto shouldReorderSelf = [&]() {
838-
if (!originalFunction->hasSelfParam())
839-
return false;
840-
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
841-
if (!indices.isWrtParameter(selfParamIndex))
842-
return false;
843-
return indices.parameters->getNumIndices() > 1;
844-
};
845-
bool reorderSelf = shouldReorderSelf();
846832

847833
// Get or create new SIL differentiability witness.
848834
// Witness already exists when there are two `@derivative` attributes (JVP and
@@ -867,28 +853,8 @@ void SILGenModule::emitDifferentiabilityWitness(
867853
// Set derivative function in differentiability witness.
868854
auto setDerivativeInDifferentiabilityWitness =
869855
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
870-
auto expectedDerivativeType =
871-
origSilFnType->getAutoDiffDerivativeFunctionType(
872-
indices.parameters, indices.source, kind, Types,
873-
LookUpConformanceInModule(M.getSwiftModule()));
874-
// Thunk derivative function.
875-
SILFunction *derivativeThunk;
876-
if (reorderSelf ||
877-
derivative->getLoweredFunctionType() != expectedDerivativeType) {
878-
derivativeThunk = getOrCreateAutoDiffDerivativeReabstractionThunk(
879-
originalFunction, silConfig, derivative, kind, reorderSelf);
880-
} else {
881-
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
882-
// the AST-level parameter indices, not the SIL-level ones.
883-
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
884-
kind, config.parameterIndices, config.derivativeGenericSignature,
885-
getASTContext());
886-
auto origDeclRef = SILDeclRef(originalAFD)
887-
.asForeign(requiresForeignEntryPoint(originalAFD));
888-
derivativeThunk = getOrCreateAutoDiffDerivativeForwardingThunk(
889-
origDeclRef.asAutoDiffDerivativeFunction(id), derivative,
890-
expectedDerivativeType);
891-
}
856+
auto derivativeThunk = getOrCreateCustomDerivativeThunk(
857+
derivative, originalFunction, silConfig, kind);
892858
// Check for existing same derivative.
893859
// TODO(TF-835): Remove condition below and simplify assertion to
894860
// `!diffWitness->getDerivative(kind)` after `@derivative` attribute

lib/SILGen/SILGen.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,6 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
147147
SILFunction *getDynamicThunk(SILDeclRef constant,
148148
CanSILFunctionType constantTy);
149149

150-
// SWIFT_ENABLE_TENSORFLOW
151-
/// Get or create an autodiff derivative function forwarding thunk for the
152-
/// given derivative SILDeclRef, SILFunction, and function type.
153-
/// The thunk simply forwards arguments and returns results: use this when no
154-
/// reabstraction or self reordering is necessary.
155-
SILFunction *getOrCreateAutoDiffDerivativeForwardingThunk(
156-
SILDeclRef derivativeFnRef, SILFunction *derivativeFn,
157-
CanSILFunctionType derivativeFnTy);
158-
159150
// SWIFT_ENABLE_TENSORFLOW
160151
/// Get or create an autodiff derivative function vtable entry thunk for the
161152
/// given SILDeclRef and derivative function type.
@@ -182,8 +173,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
182173
CanType dynamicSelfType);
183174

184175
// SWIFT_ENABLE_TENSORFLOW
185-
/// Get or create an autodiff derivative function thunk performing
186-
/// reabstraction and/or self-reordering.
176+
/// Given a user-specified custom derivative, get or create a thunk that calls
177+
/// the custom derivative, and that haswith the abstraction pattern and
178+
/// parameter ordering required for the SIL derivative of the given original
179+
/// function.
180+
///
181+
/// To achieve the required SIL derivative, the thunk may perform any subset
182+
/// of:
183+
/// - Self-reordering.
184+
/// - Reabstraction.
187185
///
188186
/// Self-reordering is done for canonicalizing the types of derivative
189187
/// functions for instance methods wrt self. We want users to define
@@ -223,12 +221,14 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
223221
/// ordering uniform for "wrt self instance method derivatives" and simplifies
224222
/// the transform rules.
225223
///
226-
/// If `reorderSelf` is true, reorder self so that it appears as:
224+
/// If self must be reordered, reorder it so that it appears as:
227225
/// - The last parameter in the returned differential.
228226
/// - The last result in the returned pullback.
229-
SILFunction *getOrCreateAutoDiffDerivativeReabstractionThunk(
230-
SILFunction *original, AutoDiffConfig config, SILFunction *derivativeFn,
231-
AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf);
227+
SILFunction *
228+
getOrCreateCustomDerivativeThunk(
229+
SILFunction *customDerivativeFn,
230+
SILFunction *originalFn, const AutoDiffConfig &config,
231+
AutoDiffDerivativeFunctionKind kind);
232232

233233
/// Determine whether the given class has any instance variables that
234234
/// need to be destroyed.

lib/SILGen/SILGenPoly.cpp

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3715,66 +3715,65 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
37153715
}
37163716

37173717
// SWIFT_ENABLE_TENSORFLOW
3718-
SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
3719-
SILFunction *original, AutoDiffConfig config, SILFunction *derivativeFn,
3720-
AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf) {
3721-
auto derivativeFnType = derivativeFn->getLoweredFunctionType();
3722-
3723-
// TODO(TF-685): Use principled thunk mangling.
3724-
// Do not simply reuse reabstraction thunk mangling.
3725-
Mangle::ASTMangler mangler;
3726-
auto name = getASTContext()
3727-
.getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper(
3728-
original->getName(), derivativeFnKind, config))
3729-
.str();
3730-
auto *thunkGenericEnv = derivativeFnType->getSubstGenericSignature()
3731-
? derivativeFnType->getSubstGenericSignature()->getGenericEnvironment()
3718+
SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
3719+
SILFunction *customDerivativeFn, SILFunction *originalFn,
3720+
const AutoDiffConfig &config, AutoDiffDerivativeFunctionKind kind) {
3721+
auto indices = config.getSILAutoDiffIndices();
3722+
3723+
auto customDerivativeFnTy = customDerivativeFn->getLoweredFunctionType();
3724+
auto *thunkGenericEnv = customDerivativeFnTy->getSubstGenericSignature()
3725+
? customDerivativeFnTy->getSubstGenericSignature()->getGenericEnvironment()
37323726
: nullptr;
37333727

3734-
auto origFnType = original->getLoweredFunctionType();
3735-
assert(config.resultIndices->getNumIndices() == 1 &&
3736-
"Only single result index is currently supported");
3728+
auto origFnTy = originalFn->getLoweredFunctionType();
37373729
CanGenericSignature derivativeCanGenSig;
37383730
if (auto derivativeGenSig = config.derivativeGenericSignature)
37393731
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
3740-
auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType(
3741-
config.parameterIndices, *config.resultIndices->getIndices().begin(),
3742-
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
3732+
auto thunkFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
3733+
indices.parameters, indices.source,
3734+
kind, Types, LookUpConformanceInModule(M.getSwiftModule()),
37433735
derivativeCanGenSig);
3744-
assert(!origDerivativeFnType->getExtInfo().hasContext());
3736+
assert(!thunkFnTy->getExtInfo().hasContext());
3737+
3738+
// TODO(TF-685): Use principled thunk mangling.
3739+
// Do not simply reuse reabstraction thunk mangling.
3740+
Mangle::ASTMangler mangler;
3741+
auto name = getASTContext().getIdentifier(
3742+
mangler.mangleAutoDiffDerivativeFunctionHelper(
3743+
originalFn->getName(), kind, config)).str();
37453744

3746-
auto loc = derivativeFn->getLocation();
3745+
auto loc = customDerivativeFn->getLocation();
37473746
SILGenFunctionBuilder fb(*this);
37483747
// This thunk is publicly exposed and cannot be transparent.
37493748
// Instead, mark it as "always inline" for optimization.
37503749
auto *thunk = fb.getOrCreateFunction(
3751-
loc, name, original->getLinkage(), origDerivativeFnType, IsBare,
3752-
IsNotTransparent, derivativeFn->isSerialized(),
3753-
derivativeFn->isDynamicallyReplaceable(), derivativeFn->getEntryCount(),
3754-
derivativeFn->isThunk(), derivativeFn->getClassSubclassScope());
3750+
loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare,
3751+
IsNotTransparent, customDerivativeFn->isSerialized(),
3752+
customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(),
3753+
IsThunk, customDerivativeFn->getClassSubclassScope());
37553754
thunk->setInlineStrategy(AlwaysInline);
37563755
if (!thunk->empty())
37573756
return thunk;
37583757
thunk->setGenericEnvironment(thunkGenericEnv);
37593758

3760-
SILGenFunction thunkSGF(*this, *thunk, derivativeFn->getDeclContext());
3759+
SILGenFunction thunkSGF(*this, *thunk, customDerivativeFn->getDeclContext());
37613760
SmallVector<ManagedValue, 4> params;
37623761
SmallVector<SILArgument *, 4> indirectResults;
37633762
thunkSGF.collectThunkParams(loc, params, &indirectResults);
37643763

3765-
auto *derivativeFnRef = thunkSGF.B.createFunctionRef(loc, derivativeFn);
3766-
auto derivativeFnRefType =
3767-
derivativeFnRef->getType().castTo<SILFunctionType>();
3764+
auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
3765+
auto fnRefType =
3766+
fnRef->getType().castTo<SILFunctionType>();
37683767

37693768
// Collect thunk arguments, converting ownership.
37703769
SmallVector<SILValue, 8> arguments;
37713770
for (auto *indRes : indirectResults)
37723771
arguments.push_back(indRes);
3773-
forwardFunctionArguments(thunkSGF, loc, derivativeFnRefType, params,
3772+
forwardFunctionArguments(thunkSGF, loc, fnRefType, params,
37743773
arguments);
37753774
// Apply function argument.
37763775
auto apply = thunkSGF.emitApplyWithRethrow(
3777-
loc, derivativeFnRef, /*substFnType*/ derivativeFnRef->getType(),
3776+
loc, fnRef, /*substFnType*/ fnRef->getType(),
37783777
thunk->getForwardingSubstitutionMap(), arguments);
37793778

37803779
// Create return instruction in the thunk, first deallocating local
@@ -3787,15 +3786,27 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
37873786
thunkSGF.B.createReturn(loc, retValue);
37883787
};
37893788

3789+
// Self reordering thunk is necessary if wrt at least two parameters,
3790+
// including self.
3791+
auto shouldReorderSelf = [&]() {
3792+
if (!originalFn->hasSelfParam())
3793+
return false;
3794+
auto selfParamIndex = origFnTy->getNumParameters() - 1;
3795+
if (!indices.isWrtParameter(selfParamIndex))
3796+
return false;
3797+
return indices.parameters->getNumIndices() > 1;
3798+
};
3799+
bool reorderSelf = shouldReorderSelf();
3800+
37903801
// If self ordering is not necessary and linear map types are unchanged,
37913802
// return the `apply` instruction.
37923803
auto linearMapFnType = cast<SILFunctionType>(
37933804
thunk
37943805
->mapTypeIntoContext(
3795-
derivativeFnRefType->getResults().back().getInterfaceType())
3806+
fnRefType->getResults().back().getInterfaceType())
37963807
->getCanonicalType());
37973808
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
3798-
origDerivativeFnType->getResults().back().getSILStorageInterfaceType())
3809+
thunkFnTy->getResults().back().getSILStorageInterfaceType())
37993810
.castTo<SILFunctionType>();
38003811
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
38013812
createReturn(apply);
@@ -3807,7 +3818,7 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
38073818
extractAllElements(apply, loc, thunkSGF.B, directResults);
38083819
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
38093820
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
3810-
auto linearMapKind = derivativeFnKind.getLinearMapKind();
3821+
auto linearMapKind = kind.getLinearMapKind();
38113822
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
38123823
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
38133824
reorderSelf);

lib/SILGen/SILGenThunk.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -70,47 +70,6 @@ SILFunction *SILGenModule::getDynamicThunk(SILDeclRef constant,
7070
return F;
7171
}
7272

73-
// SWIFT_ENABLE_TENSORFLOW
74-
SILFunction *
75-
SILGenModule::getOrCreateAutoDiffDerivativeForwardingThunk(
76-
SILDeclRef derivativeFnDeclRef, SILFunction *derivativeFn,
77-
CanSILFunctionType derivativeFnTy) {
78-
auto *autoDiffFuncId =
79-
derivativeFnDeclRef.autoDiffDerivativeFunctionIdentifier;
80-
assert(autoDiffFuncId);
81-
auto *derivativeFnDecl = derivativeFnDeclRef.getDecl();
82-
83-
SILGenFunctionBuilder builder(*this);
84-
auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction();
85-
auto name = derivativeFnDeclRef.mangle();
86-
// This thunk is publicly exposed and cannot be transparent.
87-
// Instead, mark it as "always inline" for optimization.
88-
auto *thunk = builder.getOrCreateFunction(
89-
derivativeFnDecl, name, originalFn.getLinkage(ForDefinition),
90-
derivativeFnTy, IsBare, IsNotTransparent,
91-
derivativeFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(),
92-
IsThunk);
93-
thunk->setInlineStrategy(AlwaysInline);
94-
if (!thunk->empty())
95-
return thunk;
96-
97-
if (auto genSig = derivativeFnTy->getSubstGenericSignature())
98-
thunk->setGenericEnvironment(genSig->getGenericEnvironment());
99-
SILGenFunction SGF(*this, *thunk, SwiftModule);
100-
SmallVector<ManagedValue, 4> params;
101-
auto loc = derivativeFnDeclRef.getAsRegularLocation();
102-
SGF.collectThunkParams(loc, params);
103-
auto derivativeFnRef = SGF.B.createFunctionRef(loc, derivativeFn);
104-
auto autoDiffDerivativeFnSILTy = SILType::getPrimitiveObjectType(derivativeFnTy);
105-
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
106-
thunk->getArguments().end());
107-
auto apply = SGF.emitApplyWithRethrow(
108-
loc, derivativeFnRef, autoDiffDerivativeFnSILTy,
109-
SGF.getForwardingSubstitutionMap(), args);
110-
SGF.B.createReturn(loc, apply);
111-
return thunk;
112-
}
113-
11473
// SWIFT_ENABLE_TENSORFLOW
11574
SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
11675
SILDeclRef derivativeFnDeclRef, CanSILFunctionType constantTy) {

0 commit comments

Comments
 (0)