Skip to content

[AutoDiff upstream] Add differentiability witness SILGen. #30545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,25 @@ class ASTMangler : public Mangler {
Type SelfType,
ModuleDecl *Module);

/// Mangle the derivative function (JVP/VJP) for the given:
/// - Mangled original function name.
/// - Derivative function kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);

/// Mangle the linear map (differential/pullback) for the given:
/// - Mangled original function name.
/// - Linear map kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string mangleAutoDiffLinearMapHelper(StringRef name,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);

/// Mangle a SIL differentiability witness key:
/// - Mangled original function name.
/// - Parameter indices.
Expand Down
76 changes: 76 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "swift/AST/TypeAlignments.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/SourceLoc.h"
#include "llvm/ADT/StringExtras.h"

namespace swift {

Expand Down Expand Up @@ -95,6 +96,45 @@ struct DifferentiabilityWitnessFunctionKind {
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

/// SIL-level automatic differentiation indices. Consists of:
/// - Parameter indices: indices of parameters to differentiate with respect to.
/// - Result index: index of the result to differentiate from.
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
// `AutoDiffConfig` supports multiple result indices.
struct SILAutoDiffIndices {
/// The index of the dependent result to differentiate from.
unsigned source;
/// The indices for independent parameters to differentiate with respect to.
IndexSubset *parameters;

/*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
: source(source), parameters(parameters) {}

bool operator==(const SILAutoDiffIndices &other) const;

bool operator!=(const SILAutoDiffIndices &other) const {
return !(*this == other);
};

/// Returns true if `parameterIndex` is a differentiability parameter index.
bool isWrtParameter(unsigned parameterIndex) const {
return parameterIndex < parameters->getCapacity() &&
parameters->contains(parameterIndex);
}

void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;

std::string mangle() const {
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
interleave(
parameters->getIndices(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
return result;
}
};

/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
Expand All @@ -110,6 +150,11 @@ struct AutoDiffConfig {
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-913): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;

void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
};
Expand Down Expand Up @@ -282,6 +327,37 @@ void getFunctionSemanticResultTypes(
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
GenericEnvironment *genericEnv = nullptr);

/// Returns the lowered SIL parameter indices for the given AST parameter
/// indices and `AnyfunctionType`.
///
/// Notable lowering-related changes:
/// - AST tuple parameter types are exploded when lowered to SIL.
/// - AST curried `Self` parameter types become the last parameter when lowered
/// to SIL.
///
/// Examples:
///
/// AST function type: (A, B, C) -> R
/// AST parameter indices: 101, {A, C}
/// Lowered SIL function type: $(A, B, C) -> R
/// Lowered SIL parameter indices: 101
///
/// AST function type: (Self) -> (A, B, C) -> R
/// AST parameter indices: 1010, {Self, B}
/// Lowered SIL function type: $(A, B, C, Self) -> R
/// Lowered SIL parameter indices: 0101
///
/// AST function type: (A, (B, C), D) -> R
/// AST parameter indices: 110, {A, (B, C)}
/// Lowered SIL function type: $(A, B, C, D) -> R
/// Lowered SIL parameter indices: 1110
///
/// Note:
/// - The AST function type must not be curried unless it is a method.
/// Otherwise, the behavior is undefined.
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
AnyFunctionType *functionType);

/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///
Expand Down
51 changes: 51 additions & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,57 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
return finalize();
}

std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
StringRef name, AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
beginManglingWithoutPrefix();

Buffer << "AD__" << name << '_';
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
Buffer << "_jvp_";
break;
case AutoDiffDerivativeFunctionKind::VJP:
Buffer << "_vjp_";
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
}

auto result = Storage.str().str();
Storage.clear();
return result;
}

std::string ASTMangler::mangleAutoDiffLinearMapHelper(
StringRef name, AutoDiffLinearMapKind kind, AutoDiffConfig config) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
beginManglingWithoutPrefix();

Buffer << "AD__" << name << '_';
switch (kind) {
case AutoDiffLinearMapKind::Differential:
Buffer << "_differential_";
break;
case AutoDiffLinearMapKind::Pullback:
Buffer << "_pullback_";
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
}

auto result = Storage.str().str();
Storage.clear();
return result;
}

std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
SILDifferentiabilityWitnessKey key) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
Expand Down
54 changes: 54 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
}
}

void SILAutoDiffIndices::print(llvm::raw_ostream &s) const {
s << "(source=" << source << " parameters=(";
interleave(
parameters->getIndices(), [&s](unsigned p) { s << p; },
[&s] { s << ' '; });
s << "))";
}

void SILAutoDiffIndices::dump() const {
print(llvm::errs());
llvm::errs() << '\n';
}

SILAutoDiffIndices AutoDiffConfig::getSILAutoDiffIndices() const {
assert(resultIndices->getNumIndices() == 1);
return SILAutoDiffIndices(*resultIndices->begin(), parameterIndices);
}

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
Expand Down Expand Up @@ -138,6 +156,42 @@ void autodiff::getFunctionSemanticResultTypes(
}
}

// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
IndexSubset *
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,
AnyFunctionType *functionType) {
SmallVector<AnyFunctionType *, 2> curryLevels;
unwrapCurryLevels(functionType, curryLevels);

// Compute the lowered sizes of all AST parameter types.
SmallVector<unsigned, 8> paramLoweredSizes;
unsigned totalLoweredSize = 0;
auto addLoweredParamInfo = [&](Type type) {
unsigned paramLoweredSize = countNumFlattenedElementTypes(type);
paramLoweredSizes.push_back(paramLoweredSize);
totalLoweredSize += paramLoweredSize;
};
for (auto *curryLevel : llvm::reverse(curryLevels))
for (auto &param : curryLevel->getParams())
addLoweredParamInfo(param.getPlainType());

// Build lowered SIL parameter indices by setting the range of bits that
// corresponds to each "set" AST parameter.
llvm::SmallVector<unsigned, 8> loweredSILIndices;
unsigned currentBitIndex = 0;
for (unsigned i : range(parameterIndices->getCapacity())) {
auto paramLoweredSize = paramLoweredSizes[i];
if (parameterIndices->contains(i)) {
auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize);
loweredSILIndices.append(indices.begin(), indices.end());
}
currentBitIndex += paramLoweredSize;
}

return IndexSubset::get(functionType->getASTContext(), totalLoweredSize,
loweredSILIndices);
}

GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
Expand Down
126 changes: 126 additions & 0 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,132 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
LLVM_DEBUG(llvm::dbgs() << "lowered sil:\n";
F->print(llvm::dbgs()));
F->verify();

emitDifferentiabilityWitnessesForFunction(constant, F);
}

void SILGenModule::emitDifferentiabilityWitnessesForFunction(
SILDeclRef constant, SILFunction *F) {
// Visit `@differentiable` amd `@derivative` attributes and generate SIL
// differentiability witnesses.
// Skip if the SILDeclRef is a:
// - Default argument generator function.
// - Thunk.
if (!constant.hasDecl() || !constant.getAbstractFunctionDecl())
return;
if (constant.kind == SILDeclRef::Kind::DefaultArgGenerator ||
constant.isThunk())
return;
auto *AFD = constant.getAbstractFunctionDecl();
auto emitWitnesses = [&](DeclAttributes &Attrs) {
for (auto *diffAttr : Attrs.getAttributes<DifferentiableAttr>()) {
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
if (auto *jvpDecl = diffAttr->getJVPFunction())
jvp = getFunction(SILDeclRef(jvpDecl), ForDefinition);
if (auto *vjpDecl = diffAttr->getVJPFunction())
vjp = getFunction(SILDeclRef(vjpDecl), ForDefinition);
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
assert((!F->getLoweredFunctionType()->getSubstGenericSignature() ||
diffAttr->getDerivativeGenericSignature()) &&
"Type-checking should resolve derivative generic signatures for "
"all original SIL functions with generic signatures");
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature());
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
}
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
switch (derivAttr->getDerivativeKind()) {
case AutoDiffDerivativeFunctionKind::JVP:
jvp = F;
break;
case AutoDiffDerivativeFunctionKind::VJP:
vjp = F;
break;
}
auto *origAFD = derivAttr->getOriginalFunction();
auto origDeclRef =
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
auto *origFn = getFunction(origDeclRef, NotForDefinition);
auto derivativeGenSig = AFD->getGenericSignature();
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
derivativeGenSig);
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
derivAttr);
}
};
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
if (accessor->isGetter())
emitWitnesses(accessor->getStorage()->getAttrs());
emitWitnesses(AFD->getAttrs());
}

void SILGenModule::emitDifferentiabilityWitness(
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
const DeclAttribute *attr) {
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
auto origSilFnType = originalFunction->getLoweredFunctionType();
auto *silParamIndices =
autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType);
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
// parameters corresponding to captured variables. These parameters do not
// appear in the type of `origFnType`.
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
// take `CaptureInfo` into account.
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
silParamIndices = silParamIndices->extendingCapacity(
getASTContext(), origSilFnType->getNumParameters());

// Get or create new SIL differentiability witness.
// Witness already exists when there are two `@derivative` attributes
// (registering JVP and VJP functions) for the same derivative function
// configuration.
// Witness JVP and VJP are set below.
AutoDiffConfig silConfig(silParamIndices, config.resultIndices,
config.derivativeGenericSignature);
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
if (!diffWitness) {
// Strip external from linkage of original function.
// Necessary for Clang-imported functions, which have external linkage.
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
diffWitness = SILDifferentiabilityWitness::createDefinition(
M, linkage, originalFunction, silConfig.parameterIndices,
silConfig.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
attr);
}

// Set derivative function in differentiability witness.
auto setDerivativeInDifferentiabilityWitness =
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
auto derivativeThunk = getOrCreateCustomDerivativeThunk(
derivative, originalFunction, silConfig, kind);
// Check for existing same derivative.
// TODO(TF-835): Remove condition below and simplify assertion to
// `!diffWitness->getDerivative(kind)` after `@derivative` attribute
// type-checking no longer generates implicit `@differentiable`
// attributes.
auto *existingDerivative = diffWitness->getDerivative(kind);
if (existingDerivative && existingDerivative == derivativeThunk)
return;
assert(!existingDerivative &&
"SIL differentiability witness already has a different existing "
"derivative");
diffWitness->setDerivative(kind, derivativeThunk);
};
if (jvp)
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::JVP,
jvp);
if (vjp)
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::VJP,
vjp);
}

void SILGenModule::
Expand Down
Loading