Skip to content

[AutoDiff] simplify SILGen thunking and set correct thunk linkage #28616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 4 additions & 38 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,9 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
if (auto *jvpDecl = diffAttr->getJVPFunction())
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
jvp = getFunction(SILDeclRef(jvpDecl), ForDefinition);
if (auto *vjpDecl = diffAttr->getVJPFunction())
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
vjp = getFunction(SILDeclRef(vjpDecl), ForDefinition);
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
assert((!F->getLoweredFunctionType()->getSubstGenericSignature() ||
diffAttr->getDerivativeGenericSignature()) &&
Expand Down Expand Up @@ -829,20 +829,6 @@ void SILGenModule::emitDifferentiabilityWitness(
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
silParamIndices = silParamIndices->extendingCapacity(
getASTContext(), origSilFnType->getNumParameters());
// TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`.
SILAutoDiffIndices indices(/*source*/ 0, silParamIndices);

// Self reordering thunk is necessary if wrt at least two parameters,
// including self.
auto shouldReorderSelf = [&]() {
if (!originalFunction->hasSelfParam())
return false;
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
if (!indices.isWrtParameter(selfParamIndex))
return false;
return indices.parameters->getNumIndices() > 1;
};
bool reorderSelf = shouldReorderSelf();

// Get or create new SIL differentiability witness.
// Witness already exists when there are two `@derivative` attributes (JVP and
Expand All @@ -867,28 +853,8 @@ void SILGenModule::emitDifferentiabilityWitness(
// Set derivative function in differentiability witness.
auto setDerivativeInDifferentiabilityWitness =
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
auto expectedDerivativeType =
origSilFnType->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source, kind, Types,
LookUpConformanceInModule(M.getSwiftModule()));
// Thunk derivative function.
SILFunction *derivativeThunk;
if (reorderSelf ||
derivative->getLoweredFunctionType() != expectedDerivativeType) {
derivativeThunk = getOrCreateAutoDiffDerivativeReabstractionThunk(
originalFunction, silConfig, derivative, kind, reorderSelf);
} else {
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
// the AST-level parameter indices, not the SIL-level ones.
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
kind, config.parameterIndices, config.derivativeGenericSignature,
getASTContext());
auto origDeclRef = SILDeclRef(originalAFD)
.asForeign(requiresForeignEntryPoint(originalAFD));
derivativeThunk = getOrCreateAutoDiffDerivativeForwardingThunk(
origDeclRef.asAutoDiffDerivativeFunction(id), derivative,
expectedDerivativeType);
}
auto derivativeThunk = getOrCreateCustomDerivativeThunk(
derivative, originalFunction, silConfig, kind);
// Check for existing same derivative.
// TODO(TF-835): Remove condition below and simplify assertion to
// `!diffWitness->getDerivative(kind)` after `@derivative` attribute
Expand Down
30 changes: 15 additions & 15 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,6 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
SILFunction *getDynamicThunk(SILDeclRef constant,
CanSILFunctionType constantTy);

// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff derivative function forwarding thunk for the
/// given derivative SILDeclRef, SILFunction, and function type.
/// The thunk simply forwards arguments and returns results: use this when no
/// reabstraction or self reordering is necessary.
SILFunction *getOrCreateAutoDiffDerivativeForwardingThunk(
SILDeclRef derivativeFnRef, SILFunction *derivativeFn,
CanSILFunctionType derivativeFnTy);

// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff derivative function vtable entry thunk for the
/// given SILDeclRef and derivative function type.
Expand All @@ -182,8 +173,15 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
CanType dynamicSelfType);

// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff derivative function thunk performing
/// reabstraction and/or self-reordering.
/// Given a user-specified custom derivative, get or create a thunk that calls
/// the custom derivative, and that has the abstraction pattern and
/// parameter ordering required for the SIL derivative of the given original
/// function.
///
/// To achieve the required SIL derivative, the thunk may perform any subset
/// of:
/// - Self-reordering.
/// - Reabstraction.
///
/// Self-reordering is done for canonicalizing the types of derivative
/// functions for instance methods wrt self. We want users to define
Expand Down Expand Up @@ -223,12 +221,14 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// ordering uniform for "wrt self instance method derivatives" and simplifies
/// the transform rules.
///
/// If `reorderSelf` is true, reorder self so that it appears as:
/// If self must be reordered, reorder it so that it appears as:
/// - The last parameter in the returned differential.
/// - The last result in the returned pullback.
SILFunction *getOrCreateAutoDiffDerivativeReabstractionThunk(
SILFunction *original, AutoDiffConfig config, SILFunction *derivativeFn,
AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf);
SILFunction *
getOrCreateCustomDerivativeThunk(
SILFunction *customDerivativeFn,
SILFunction *originalFn, const AutoDiffConfig &config,
AutoDiffDerivativeFunctionKind kind);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
Expand Down
81 changes: 46 additions & 35 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3715,66 +3715,65 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
}

// SWIFT_ENABLE_TENSORFLOW
SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
SILFunction *original, AutoDiffConfig config, SILFunction *derivativeFn,
AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf) {
auto derivativeFnType = derivativeFn->getLoweredFunctionType();

// TODO(TF-685): Use principled thunk mangling.
// Do not simply reuse reabstraction thunk mangling.
Mangle::ASTMangler mangler;
auto name = getASTContext()
.getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper(
original->getName(), derivativeFnKind, config))
.str();
auto *thunkGenericEnv = derivativeFnType->getSubstGenericSignature()
? derivativeFnType->getSubstGenericSignature()->getGenericEnvironment()
SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
SILFunction *customDerivativeFn, SILFunction *originalFn,
const AutoDiffConfig &config, AutoDiffDerivativeFunctionKind kind) {
auto indices = config.getSILAutoDiffIndices();

auto customDerivativeFnTy = customDerivativeFn->getLoweredFunctionType();
auto *thunkGenericEnv = customDerivativeFnTy->getSubstGenericSignature()
? customDerivativeFnTy->getSubstGenericSignature()->getGenericEnvironment()
: nullptr;

auto origFnType = original->getLoweredFunctionType();
assert(config.resultIndices->getNumIndices() == 1 &&
"Only single result index is currently supported");
auto origFnTy = originalFn->getLoweredFunctionType();
CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = config.derivativeGenericSignature)
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
auto origDerivativeFnType = origFnType->getAutoDiffDerivativeFunctionType(
config.parameterIndices, *config.resultIndices->getIndices().begin(),
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
auto thunkFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
indices.parameters, indices.source,
kind, Types, LookUpConformanceInModule(M.getSwiftModule()),
derivativeCanGenSig);
assert(!origDerivativeFnType->getExtInfo().hasContext());
assert(!thunkFnTy->getExtInfo().hasContext());

// TODO(TF-685): Use principled thunk mangling.
// Do not simply reuse reabstraction thunk mangling.
Mangle::ASTMangler mangler;
auto name = getASTContext().getIdentifier(
mangler.mangleAutoDiffDerivativeFunctionHelper(
originalFn->getName(), kind, config)).str();

auto loc = derivativeFn->getLocation();
auto loc = customDerivativeFn->getLocation();
SILGenFunctionBuilder fb(*this);
// This thunk is publicly exposed and cannot be transparent.
// Instead, mark it as "always inline" for optimization.
auto *thunk = fb.getOrCreateFunction(
loc, name, original->getLinkage(), origDerivativeFnType, IsBare,
IsNotTransparent, derivativeFn->isSerialized(),
derivativeFn->isDynamicallyReplaceable(), derivativeFn->getEntryCount(),
derivativeFn->isThunk(), derivativeFn->getClassSubclassScope());
loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare,
IsNotTransparent, customDerivativeFn->isSerialized(),
customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(),
IsThunk, customDerivativeFn->getClassSubclassScope());
thunk->setInlineStrategy(AlwaysInline);
if (!thunk->empty())
return thunk;
thunk->setGenericEnvironment(thunkGenericEnv);

SILGenFunction thunkSGF(*this, *thunk, derivativeFn->getDeclContext());
SILGenFunction thunkSGF(*this, *thunk, customDerivativeFn->getDeclContext());
SmallVector<ManagedValue, 4> params;
SmallVector<SILArgument *, 4> indirectResults;
thunkSGF.collectThunkParams(loc, params, &indirectResults);

auto *derivativeFnRef = thunkSGF.B.createFunctionRef(loc, derivativeFn);
auto derivativeFnRefType =
derivativeFnRef->getType().castTo<SILFunctionType>();
auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
auto fnRefType =
fnRef->getType().castTo<SILFunctionType>();

// Collect thunk arguments, converting ownership.
SmallVector<SILValue, 8> arguments;
for (auto *indRes : indirectResults)
arguments.push_back(indRes);
forwardFunctionArguments(thunkSGF, loc, derivativeFnRefType, params,
forwardFunctionArguments(thunkSGF, loc, fnRefType, params,
arguments);
// Apply function argument.
auto apply = thunkSGF.emitApplyWithRethrow(
loc, derivativeFnRef, /*substFnType*/ derivativeFnRef->getType(),
loc, fnRef, /*substFnType*/ fnRef->getType(),
thunk->getForwardingSubstitutionMap(), arguments);

// Create return instruction in the thunk, first deallocating local
Expand All @@ -3787,15 +3786,27 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
thunkSGF.B.createReturn(loc, retValue);
};

// Self reordering thunk is necessary if wrt at least two parameters,
// including self.
auto shouldReorderSelf = [&]() {
if (!originalFn->hasSelfParam())
return false;
auto selfParamIndex = origFnTy->getNumParameters() - 1;
if (!indices.isWrtParameter(selfParamIndex))
return false;
return indices.parameters->getNumIndices() > 1;
};
bool reorderSelf = shouldReorderSelf();

// If self ordering is not necessary and linear map types are unchanged,
// return the `apply` instruction.
auto linearMapFnType = cast<SILFunctionType>(
thunk
->mapTypeIntoContext(
derivativeFnRefType->getResults().back().getInterfaceType())
fnRefType->getResults().back().getInterfaceType())
->getCanonicalType());
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
origDerivativeFnType->getResults().back().getSILStorageInterfaceType())
thunkFnTy->getResults().back().getSILStorageInterfaceType())
.castTo<SILFunctionType>();
if (!reorderSelf && linearMapFnType == targetLinearMapFnType) {
createReturn(apply);
Expand All @@ -3807,7 +3818,7 @@ SILFunction *SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
extractAllElements(apply, loc, thunkSGF.B, directResults);
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
auto linearMapKind = derivativeFnKind.getLinearMapKind();
auto linearMapKind = kind.getLinearMapKind();
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
reorderSelf);
Expand Down
41 changes: 0 additions & 41 deletions lib/SILGen/SILGenThunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,47 +70,6 @@ SILFunction *SILGenModule::getDynamicThunk(SILDeclRef constant,
return F;
}

// SWIFT_ENABLE_TENSORFLOW
/// Get or create a forwarding thunk for the derivative function identified by
/// `derivativeFnDeclRef`.
///
/// The thunk is named with the mangled name of `derivativeFnDeclRef`, has type
/// `derivativeFnTy`, and its body applies `derivativeFn` to the thunk's own
/// arguments and returns the result unchanged. No reabstraction or reordering
/// of arguments or results is performed here.
///
/// \param derivativeFnDeclRef decl ref of the derivative function; it must
///        carry an autodiff derivative function identifier (asserted below).
/// \param derivativeFn the SIL function that the thunk calls.
/// \param derivativeFnTy the lowered function type given to the thunk.
/// \returns the existing thunk if it already has a body, otherwise the newly
///          emitted one.
SILFunction *
SILGenModule::getOrCreateAutoDiffDerivativeForwardingThunk(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe as is, getOrCreateAutoDiffDerivativeForwardingThunk created a more efficient thunk, as it never reabstracts the returned linear map while getOrCreateCustomDerivativeThunk always does. Not a big deal!

SILDeclRef derivativeFnDeclRef, SILFunction *derivativeFn,
CanSILFunctionType derivativeFnTy) {
// The identifier is only checked for presence; it is not otherwise read here.
auto *autoDiffFuncId =
derivativeFnDeclRef.autoDiffDerivativeFunctionIdentifier;
assert(autoDiffFuncId);
auto *derivativeFnDecl = derivativeFnDeclRef.getDecl();

SILGenFunctionBuilder builder(*this);
// The thunk's linkage is taken from `originalFn` (for definition), not from
// `derivativeFn`.
auto originalFn = derivativeFnDeclRef.asAutoDiffOriginalFunction();
auto name = derivativeFnDeclRef.mangle();
// This thunk is publicly exposed and cannot be transparent.
// Instead, mark it as "always inline" for optimization.
auto *thunk = builder.getOrCreateFunction(
derivativeFnDecl, name, originalFn.getLinkage(ForDefinition),
derivativeFnTy, IsBare, IsNotTransparent,
derivativeFnDeclRef.isSerialized(), IsNotDynamic, ProfileCounter(),
IsThunk);
thunk->setInlineStrategy(AlwaysInline);
// A non-empty thunk already has a body from an earlier call; reuse it.
if (!thunk->empty())
return thunk;

// Only set a generic environment when the thunk type has a generic signature.
if (auto genSig = derivativeFnTy->getSubstGenericSignature())
thunk->setGenericEnvironment(genSig->getGenericEnvironment());
SILGenFunction SGF(*this, *thunk, SwiftModule);
// NOTE(review): `params` is filled by `collectThunkParams` but not read
// afterwards in this function; the call is presumably needed to create the
// thunk's entry arguments — TODO confirm.
SmallVector<ManagedValue, 4> params;
auto loc = derivativeFnDeclRef.getAsRegularLocation();
SGF.collectThunkParams(loc, params);
auto derivativeFnRef = SGF.B.createFunctionRef(loc, derivativeFn);
auto autoDiffDerivativeFnSILTy = SILType::getPrimitiveObjectType(derivativeFnTy);
// Forward every thunk argument as-is to the derivative function.
SmallVector<SILValue, 4> args(thunk->getArguments().begin(),
thunk->getArguments().end());
auto apply = SGF.emitApplyWithRethrow(
loc, derivativeFnRef, autoDiffDerivativeFnSILTy,
SGF.getForwardingSubstitutionMap(), args);
// Return the call result directly.
SGF.B.createReturn(loc, apply);
return thunk;
}

// SWIFT_ENABLE_TENSORFLOW
SILFunction *SILGenModule::getOrCreateAutoDiffClassMethodThunk(
SILDeclRef derivativeFnDeclRef, CanSILFunctionType constantTy) {
Expand Down
Loading