[AutoDiff upstream] Add SIL derivative function type calculation. #29396

Merged
83 changes: 83 additions & 0 deletions include/swift/AST/Types.h
@@ -4390,6 +4390,89 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

const clang::FunctionType *getClangFunctionType() const;

/// Returns the type of the derivative function for the given parameter
/// indices, result index, derivative function kind, derivative function
/// generic signature (optional), and other auxiliary parameters.
///
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - The result corresponding to the result index must conform to
/// `Differentiable`.
///
/// Typing rules, given:
/// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
///
/// Terminology:
/// - The derivative of a value of a `Differentiable`-conforming type `T` has
///   type `T.TangentVector`, the `TangentVector` associated type of `T`.
///   `TangentVector` is abbreviated as `Tan` below.
/// - "wrt" parameters refers to parameters indicated by the parameter
/// indices.
/// - "wrt" result refers to the result indicated by the result index.
///
/// JVP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a differential function, which
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
///
///     $(T0, ...) -> (R0, ...,  (T0.Tan, T1.Tan, ...) -> R0.Tan)
///                    ^~~~~~~    ^~~~~~~~~~~~~~~~~~~     ^~~~~~
///          original results | derivatives wrt params | derivative wrt result
///
/// VJP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a pullback function, which
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
///
///     $(T0, ...) -> (R0, ...,  (R0.Tan) -> (T0.Tan, T1.Tan, ...))
///                    ^~~~~~~    ^~~~~~      ^~~~~~~~~~~~~~~~~~~
///          original results | derivative wrt result | derivatives wrt params
///
/// A "constrained derivative generic signature" is computed from
/// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
/// computed from the original generic signature. A "constrained derivative
/// generic signature" requires all "wrt" parameters to conform to
/// `Differentiable`; this is important for correctness.
///
/// This "constrained derivative generic signature" is used for
/// parameter/result type lowering. It is used as the actual generic signature
/// of the derivative function type iff the original function type has a
/// generic signature and not all generic parameters are bound to concrete
/// types. Otherwise, no derivative generic signature is used.
///
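/// Other properties of the original function type are copied exactly:
/// coroutine kind, callee convention, yields, optional error result,
/// witness method conformance, etc. `ExtInfo` is also copied, except that a
/// `@convention(c)` original function yields a `@convention(thin)`
/// derivative function.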
///
/// Special cases:
/// - Reabstraction thunks have special derivative type calculation. The
/// original function-typed last parameter is transformed into a
/// `@differentiable` function-typed parameter in the derivative type. This
/// is necessary for the differentiation transform to support reabstraction
/// thunk differentiation because the function argument is opaque and cannot
/// be differentiated. Instead, the argument is made `@differentiable` and
/// reabstraction thunk JVP/VJP callers are responsible for passing a
/// `@differentiable` function.
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
/// derivative approaches. The last argument can simply be a
/// corresponding derivative function, instead of a `@differentiable`
/// function - this is more direct. It may be possible to implement
/// reabstraction thunk derivatives using "reabstraction thunks for
/// the original function's derivative", avoiding extra code generation.
///
/// Caveats:
/// - We may support multiple result indices instead of a single result index
/// eventually. At the SIL level, this enables differentiating wrt multiple
/// function results. At the Swift level, this enables differentiating wrt
/// multiple tuple elements for tuple-returning functions.
CanSILFunctionType getAutoDiffDerivativeFunctionType(
    IndexSubset *parameterIndices, unsigned resultIndex,
    AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
    LookupConformanceFn lookupConformance,
    CanGenericSignature derivativeFunctionGenericSignature = nullptr,
    bool isReabstractionThunk = false);

ExtInfo getExtInfo() const {
  return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType());
}
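
The JVP and VJP typing rules in the doc comment above can be illustrated with a small standalone sketch. It is a toy model, not compiler code: types are plain strings, `.Tan` is appended textually, and the names `DerivativeKind`, `joinTypes`, and `derivativeTypeString` are hypothetical and do not exist in the Swift codebase.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy model: types are strings, not SIL types.
enum class DerivativeKind { JVP, VJP };

// Joins type names with ", ".
static std::string joinTypes(const std::vector<std::string> &types) {
  std::string out;
  for (std::size_t i = 0; i < types.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += types[i];
  }
  return out;
}

// Builds a textual derivative function type following the typing rules:
// original parameters in, original results plus one linear map out.
std::string derivativeTypeString(const std::vector<std::string> &params,
                                 const std::vector<std::string> &results,
                                 const std::vector<std::size_t> &wrtParams,
                                 std::size_t resultIndex,
                                 DerivativeKind kind) {
  assert(resultIndex < results.size() && "result index out of range");
  std::vector<std::string> paramTans;
  for (std::size_t index : wrtParams) {
    assert(index < params.size() && "parameter index out of range");
    paramTans.push_back(params[index] + ".Tan");
  }
  const std::string resultTan = results[resultIndex] + ".Tan";
  // JVP returns a differential; VJP returns a pullback.
  const std::string linearMap =
      kind == DerivativeKind::JVP
          ? "(" + joinTypes(paramTans) + ") -> " + resultTan
          : "(" + resultTan + ") -> (" + joinTypes(paramTans) + ")";
  return "(" + joinTypes(params) + ") -> (" + joinTypes(results) + ", " +
         linearMap + ")";
}
```

Only the "wrt" parameters contribute tangent types to the linear map, while every original parameter and result still appears in the outer function type.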
191 changes: 191 additions & 0 deletions lib/SIL/SILFunctionType.cpp
@@ -22,6 +22,7 @@
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/ForeignInfo.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/Module.h"
#include "swift/AST/ModuleLoader.h"
#include "swift/AST/ProtocolConformance.h"
@@ -190,6 +191,196 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
return nullptr;
}

// Returns the canonical generic signature for an autodiff derivative function
// given an existing derivative function generic signature. All
// differentiability parameters are required to conform to `Differentiable`.
static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature(
    CanGenericSignature derivativeFnGenSig,
    ArrayRef<SILParameterInfo> originalParameters,
    IndexSubset *parameterIndices, ModuleDecl *module) {
  if (!derivativeFnGenSig)
    return nullptr;
  auto &ctx = module->getASTContext();
  GenericSignatureBuilder builder(ctx);
  // Add derivative function generic signature.
  builder.addGenericSignature(derivativeFnGenSig);
  // All differentiability parameters are required to conform to
  // `Differentiable`.
  auto source =
      GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
  auto *differentiableProtocol =
      ctx.getProtocol(KnownProtocolKind::Differentiable);
  for (unsigned paramIdx : parameterIndices->getIndices()) {
    auto paramType = originalParameters[paramIdx].getInterfaceType();
    Requirement req(RequirementKind::Conformance, paramType,
                    differentiableProtocol->getDeclaredType());
    builder.addRequirement(req, source, module);
  }
  return std::move(builder)
      .computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
      ->getCanonicalSignature();
}
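
As a rough illustration of what the helper above adds to the signature, the following standalone sketch collects one `Differentiable` conformance requirement per distinct "wrt" parameter type. It is a deliberately simplified model: types and requirements are strings, the function name `differentiableRequirements` is hypothetical, and none of the canonicalization or minimization that `GenericSignatureBuilder` performs is modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model: returns "T: Differentiable" once for each distinct
// type that appears at a "wrt" parameter index, in first-seen order.
std::vector<std::string>
differentiableRequirements(const std::vector<std::string> &paramTypes,
                           const std::vector<std::size_t> &wrtIndices) {
  std::vector<std::string> requirements;
  std::set<std::string> seen;
  for (std::size_t index : wrtIndices) {
    assert(index < paramTypes.size() && "wrt index out of range");
    const std::string &type = paramTypes[index];
    // Skip types that already have a requirement.
    if (seen.insert(type).second)
      requirements.push_back(type + ": Differentiable");
  }
  return requirements;
}
```

Parameters that are not differentiated with respect to contribute no requirement, which matches the loop over `parameterIndices->getIndices()` in the helper.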

CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
    IndexSubset *parameterIndices, unsigned resultIndex,
    AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
    LookupConformanceFn lookupConformance,
    CanGenericSignature derivativeFnGenSig, bool isReabstractionThunk) {
  auto &ctx = getASTContext();

  // Returns true if `index` is a differentiability parameter index.
  auto isDiffParamIndex = [&](unsigned index) -> bool {
    return index < parameterIndices->getCapacity() &&
           parameterIndices->contains(index);
  };

  // Calculate differentiability parameter infos.
  SmallVector<SILParameterInfo, 4> diffParams;
  for (auto valueAndIndex : enumerate(getParameters()))
    if (isDiffParamIndex(valueAndIndex.index()))
      diffParams.push_back(valueAndIndex.value());

  // Get the canonical derivative function generic signature.
  if (!derivativeFnGenSig)
    derivativeFnGenSig = getSubstGenericSignature();
  derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature(
      derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);

  // Given a tangent type and the convention of the original result, returns
  // the formal SIL parameter info for the corresponding pullback parameter.
  auto getTangentParameterInfoForOriginalResult =
      [&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
    AbstractionPattern pattern(derivativeFnGenSig, tanType);
    auto &tl =
        TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
    ParameterConvention conv;
    switch (origResConv) {
    case ResultConvention::Owned:
    case ResultConvention::Autoreleased:
      conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
                            : ParameterConvention::Direct_Guaranteed;
      break;
    case ResultConvention::Unowned:
    case ResultConvention::UnownedInnerPointer:
      conv = ParameterConvention::Direct_Unowned;
      break;
    case ResultConvention::Indirect:
      conv = ParameterConvention::Indirect_In_Guaranteed;
      break;
    }
    return {tanType, conv};
  };

  // Given a tangent type and the convention of the original parameter,
  // returns the formal SIL result info for the corresponding pullback result.
  auto getTangentResultInfoForOriginalParameter =
      [&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
    AbstractionPattern pattern(derivativeFnGenSig, tanType);
    auto &tl =
        TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
    ResultConvention conv;
    switch (origParamConv) {
    case ParameterConvention::Direct_Owned:
    case ParameterConvention::Direct_Guaranteed:
    case ParameterConvention::Direct_Unowned:
      conv =
          tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned;
      break;
    case ParameterConvention::Indirect_In:
    case ParameterConvention::Indirect_Inout:
    case ParameterConvention::Indirect_In_Constant:
    case ParameterConvention::Indirect_In_Guaranteed:
    case ParameterConvention::Indirect_InoutAliasable:
      conv = ResultConvention::Indirect;
      break;
    }
    return {tanType, conv};
  };

  CanSILFunctionType closureType;
  switch (kind) {
  case AutoDiffDerivativeFunctionKind::JVP: {
    SmallVector<SILParameterInfo, 8> differentialParams;
    for (auto &param : diffParams) {
      auto paramTan =
          param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
      assert(paramTan && "Parameter type does not have a tangent space?");
      differentialParams.push_back(
          {paramTan->getCanonicalType(), param.getConvention()});
    }
    SmallVector<SILResultInfo, 8> differentialResults;
    auto &result = getResults()[resultIndex];
    auto resultTan =
        result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
    assert(resultTan && "Result type does not have a tangent space?");
    differentialResults.push_back(
        {resultTan->getCanonicalType(), result.getConvention()});
    closureType = SILFunctionType::get(
        /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
        ParameterConvention::Direct_Guaranteed, differentialParams, {},
        differentialResults, None, getSubstitutions(),
        isGenericSignatureImplied(), ctx);
    break;
  }
  case AutoDiffDerivativeFunctionKind::VJP: {
    SmallVector<SILParameterInfo, 8> pullbackParams;
    auto &origRes = getResults()[resultIndex];
    auto resultTan =
        origRes.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
    assert(resultTan && "Result type does not have a tangent space?");
    pullbackParams.push_back(getTangentParameterInfoForOriginalResult(
        resultTan->getCanonicalType(), origRes.getConvention()));
    SmallVector<SILResultInfo, 8> pullbackResults;
    for (auto &param : diffParams) {
      auto paramTan =
          param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
      assert(paramTan && "Parameter type does not have a tangent space?");
      pullbackResults.push_back(getTangentResultInfoForOriginalParameter(
          paramTan->getCanonicalType(), param.getConvention()));
    }
    closureType = SILFunctionType::get(
        /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
        ParameterConvention::Direct_Guaranteed, pullbackParams, {},
        pullbackResults, {}, getSubstitutions(), isGenericSignatureImplied(),
        ctx);
    break;
  }
  }

  SmallVector<SILParameterInfo, 4> newParameters;
  newParameters.reserve(getNumParameters());
  for (auto &param : getParameters()) {
    newParameters.push_back(param.getWithInterfaceType(
        param.getInterfaceType()->getCanonicalType(derivativeFnGenSig)));
  }
  // TODO(TF-1124): Upstream reabstraction thunk derivative typing rules.
  // Blocked by TF-1125: `SILFunctionType::getWithDifferentiability`.
  SmallVector<SILResultInfo, 4> newResults;
  newResults.reserve(getNumResults() + 1);
  for (auto &result : getResults()) {
    newResults.push_back(result.getWithInterfaceType(
        result.getInterfaceType()->getCanonicalType(derivativeFnGenSig)));
  }
  newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
                        ResultConvention::Owned});
  // Derivative function type has a generic signature only if the original
  // function type does, and if `derivativeFnGenSig` does not have all concrete
  // generic parameters.
  CanGenericSignature canGenSig;
  if (getSubstGenericSignature() && derivativeFnGenSig &&
      !derivativeFnGenSig->areAllParamsConcrete())
    canGenSig = derivativeFnGenSig;
  // If original function is `@convention(c)`, the derivative function should
  // have `@convention(thin)`. IRGen does not support `@convention(c)` functions
  // with multiple results.
  auto extInfo = getExtInfo();
  if (getRepresentation() == SILFunctionTypeRepresentation::CFunctionPointer)
    extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin);
  return SILFunctionType::get(canGenSig, extInfo, getCoroutineKind(),
                              getCalleeConvention(), newParameters, getYields(),
                              newResults, getOptionalErrorResult(),
                              getSubstitutions(), isGenericSignatureImplied(),
                              ctx, getWitnessMethodConformanceOrInvalid());
}
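
The two convention-mapping lambdas used for the VJP pullback can be summarized in a standalone sketch. It mirrors only the `switch` logic above: the enums `ResultConv` and `ParamConv` and both function names are hypothetical stand-ins for SIL's `ResultConvention` and `ParameterConvention`, and the `tangentIsTrivial` flag stands in for the `tl.isTrivial()` query on the lowered tangent type.

```cpp
#include <cassert>

// Toy stand-ins for SIL's ResultConvention and ParameterConvention.
enum class ResultConv {
  Owned, Autoreleased, Unowned, UnownedInnerPointer, Indirect
};
enum class ParamConv {
  DirectOwned, DirectGuaranteed, DirectUnowned,
  IndirectIn, IndirectInout, IndirectInConstant,
  IndirectInGuaranteed, IndirectInoutAliasable
};

// Convention of the pullback parameter derived from an original result.
ParamConv pullbackParamConvention(ResultConv origResult,
                                  bool tangentIsTrivial) {
  switch (origResult) {
  case ResultConv::Owned:
  case ResultConv::Autoreleased:
    return tangentIsTrivial ? ParamConv::DirectUnowned
                            : ParamConv::DirectGuaranteed;
  case ResultConv::Unowned:
  case ResultConv::UnownedInnerPointer:
    return ParamConv::DirectUnowned;
  case ResultConv::Indirect:
    return ParamConv::IndirectInGuaranteed;
  }
  assert(false && "unhandled result convention");
  return ParamConv::DirectUnowned;
}

// Convention of the pullback result derived from an original parameter.
ResultConv pullbackResultConvention(ParamConv origParam,
                                    bool tangentIsTrivial) {
  switch (origParam) {
  case ParamConv::DirectOwned:
  case ParamConv::DirectGuaranteed:
  case ParamConv::DirectUnowned:
    return tangentIsTrivial ? ResultConv::Unowned : ResultConv::Owned;
  case ParamConv::IndirectIn:
  case ParamConv::IndirectInout:
  case ParamConv::IndirectInConstant:
  case ParamConv::IndirectInGuaranteed:
  case ParamConv::IndirectInoutAliasable:
    return ResultConv::Indirect;
  }
  assert(false && "unhandled parameter convention");
  return ResultConv::Indirect;
}
```

In this model, direct conventions depend on whether the tangent type is trivial, while every indirect original convention maps to a single indirect tangent convention.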

static CanType getKnownType(Optional<CanType> &cacheSlot, ASTContext &C,
StringRef moduleName, StringRef typeName) {
if (!cacheSlot) {