Skip to content

[AutoDiff upstream] Add derivative function type calculation. #29218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3190,6 +3190,76 @@ class AnyFunctionType : public TypeBase {
return getExtInfo().getRepresentation();
}

/// Returns the derivative function type for the given parameter indices,
/// result index, derivative function kind, derivative function generic
/// signature (optional), and other auxiliary parameters.
///
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - The result corresponding to the result index must conform to
/// `Differentiable`.
///
/// Typing rules, given:
/// - Original function type. Three cases:
/// - Top-level function: `(T0, T1, ...) -> R`
/// - Static method: `(Self.Type) -> (T0, T1, ...) -> R`
/// - Instance method: `(Self) -> (T0, T1, ...) -> R`
///
/// Terminology:
/// - The derivative of a `Differentiable`-conforming type has the
/// `TangentVector` associated type. `TangentVector` is abbreviated as `Tan`
/// below.
/// - "wrt" parameters refers to differentiability parameters, identified by
/// the parameter indices.
/// - "wrt" result refers to the result identified by the result index.
///
/// JVP derivative type:
/// - Takes original parameters.
/// - Returns original result, followed by a differential function, which
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
///
/// \verbatim
/// (T0, T1, ...) -> (R, (T0.Tan, T1.Tan, ...) -> R.Tan)
/// ^ ^~~~~~~~~~~~~~~~~~~ ^~~~~
/// original result | derivatives wrt params | derivative wrt result
///
/// (Self) -> (T0, ...) -> (R, (Self.Tan, T0.Tan, ...) -> R.Tan)
/// ^ ^~~~~~~~~~~~~~~~~~~~~ ^~~~~
/// original result | deriv. wrt params | deriv. wrt result
/// \endverbatim
///
/// VJP derivative type:
/// - Takes original parameters.
/// - Returns original result, followed by a pullback function, which
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
///
/// \verbatim
/// (T0, T1, ...) -> (R, (R.Tan) -> (T0.Tan, T1.Tan, ...))
/// ^ ^~~~~ ^~~~~~~~~~~~~~~~~~~
/// original result | derivative wrt result | derivatives wrt params
///
/// (Self) -> (T0, ...) -> (R, (R.Tan) -> (Self.Tan, T0.Tan, ...))
/// ^ ^~~~~ ^~~~~~~~~~~~~~~~~~~~~
/// original result | deriv. wrt result | deriv. wrt params
/// \endverbatim
///
/// By default, if the original type has a `self` parameter list and parameter
/// indices include `self`, the computed derivative function type will return
/// a linear map taking/returning self's tangent *last* instead of first, for
/// consistency with SIL.
///
/// If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
/// first. `makeSelfParamFirst` should be true when working with user-facing
/// derivative function types, e.g. when type-checking `@differentiable` and
/// `@derivative` attributes.
AnyFunctionType *getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind,
LookupConformanceFn lookupConformance,
GenericSignature derivativeGenericSignature = GenericSignature(),
bool makeSelfParamFirst = false);

/// True if the parameter declaration it is attached to is guaranteed
/// to not persist the closure for longer than the duration of the call.
bool isNoEscape() const {
Expand Down
128 changes: 128 additions & 0 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4910,6 +4910,134 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
return cache(None);
}

// Creates an `AnyFunctionType` from the given parameters, result type,
// generic signature, and `ExtInfo`.
static AnyFunctionType *
makeFunctionType(ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a general utility for AST function types. Consider adding it to AnyFunctionType, or just making it a closure inside getAutoDiffDerivativeFunctionType.

GenericSignature genericSignature,
AnyFunctionType::ExtInfo extInfo) {
if (genericSignature)
return GenericFunctionType::get(genericSignature, parameters, resultType,
extInfo);
return FunctionType::get(parameters, resultType, extInfo);
}

AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance,
GenericSignature derivativeGenSig, bool makeSelfParamFirst) {
assert(!parameterIndices->isEmpty() &&
"Expected at least one differentiability parameter");
auto &ctx = getASTContext();

// If `derivativeGenSig` is not defined, use the current function's type
// generic signature.
if (!derivativeGenSig)
derivativeGenSig = getOptGenericSignature();

// Get differentiability parameter types.
SmallVector<Type, 8> diffParamTypes;
autodiff::getSubsetParameterTypes(parameterIndices, this, diffParamTypes,
/*reverseCurryLevels*/ !makeSelfParamFirst);

// Unwrap curry levels. At most, two parameter lists are necessary, for
// curried method types with a `(Self)` parameter list.
// TODO(TF-874): Simplify curry level logic.
SmallVector<AnyFunctionType *, 2> curryLevels;
auto *currentLevel = castTo<AnyFunctionType>();
for (unsigned i : range(2)) {
(void)i;
if (currentLevel == nullptr)
break;
curryLevels.push_back(currentLevel);
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
}

Type originalResult = curryLevels.back()->getResult();

// Build the result linear map function type.
Type linearMapType;
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP: {
// Differential function type, a result of the JVP:
// `LinearMapType = (T.TangentVector, ...) -> (R.TangentVector)`
SmallVector<AnyFunctionType::Param, 8> differentialParams;
for (auto diffParamType : diffParamTypes)
differentialParams.push_back(AnyFunctionType::Param(
diffParamType->getAutoDiffTangentSpace(lookupConformance)
->getType()));
SmallVector<TupleTypeElt, 8> differentialResults;
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
differentialResults.push_back(
resultTupleEltType->getAutoDiffTangentSpace(lookupConformance)
->getType());
} else {
assert(resultIndex == 0 && "resultIndex out of bounds");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: this should match the assertion below:

      assert(resultIndex == 0 &&
             "Expected result index 0 for non-tuple result");

differentialResults.push_back(
originalResult->getAutoDiffTangentSpace(lookupConformance)
->getType());
}
Type differentialResult = differentialResults.size() > 1
? TupleType::get(differentialResults, ctx)
: differentialResults[0].getType();
linearMapType = FunctionType::get(differentialParams, differentialResult);
break;
}
case AutoDiffDerivativeFunctionKind::VJP: {
// Pullback function type, a result of the VJP:
// `LinearMapType = (R.TangentVector) -> (T.TangentVector, ...)`
SmallVector<AnyFunctionType::Param, 8> pullbackParams;
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
pullbackParams.push_back(AnyFunctionType::Param(
resultTupleEltType->getAutoDiffTangentSpace(lookupConformance)
->getType()));
} else {
assert(resultIndex == 0 &&
"Expected result index 0 for non-tuple result");
pullbackParams.push_back(AnyFunctionType::Param(
originalResult->getAutoDiffTangentSpace(lookupConformance)
->getType()));
}
SmallVector<TupleTypeElt, 8> pullbackResults;
for (auto diffParamType : diffParamTypes)
pullbackResults.push_back(
diffParamType->getAutoDiffTangentSpace(lookupConformance)->getType());
Type pullbackResult = pullbackResults.size() > 1
? TupleType::get(pullbackResults, ctx)
: pullbackResults[0].getType();
linearMapType = FunctionType::get(pullbackParams, pullbackResult);
break;
}
}
assert(linearMapType && "Expected linear map type");

// Build the full derivative function type: `(T...) -> (R, LinearMapType)`.
SmallVector<TupleTypeElt, 2> retElts;
retElts.push_back(originalResult);
retElts.push_back(linearMapType);
auto retTy = TupleType::get(retElts, ctx);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: I prefer using the name "result type" instead of "return type". retTy -> resultTy

auto *derivativeFunctionType =
makeFunctionType(curryLevels.back()->getParams(), retTy,
curryLevels.size() == 1 ? derivativeGenSig : nullptr,
curryLevels.back()->getExtInfo());

// Wrap the derivative function type in additional curry levels.
auto curryLevelsWithoutLast =
ArrayRef<AnyFunctionType *>(curryLevels).drop_back(1);
for (auto pair : enumerate(llvm::reverse(curryLevelsWithoutLast))) {
unsigned i = pair.index();
auto *curryLevel = pair.value();
derivativeFunctionType = makeFunctionType(
curryLevel->getParams(), derivativeFunctionType,
i == curryLevelsWithoutLast.size() - 1 ? derivativeGenSig : nullptr,
curryLevel->getExtInfo());
}

return derivativeFunctionType;
}

CanSILFunctionType
SILFunctionType::withSubstitutions(SubstitutionMap subs) const {
return SILFunctionType::get(getSubstGenericSignature(),
Expand Down