Skip to content

Commit 732e0f4

Browse files
committed
[AutoDiff upstream] Add derivative function type calculation.
Add `AnyFunctionType::getAutoDiffDerivativeFunctionType`, which returns the derivative `AnyFunctionType` for an "original" `AnyFunctionType`, given:
- Differentiability parameter indices
- Differentiability result index
- Derivative function kind
- Derivative function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.

Progress towards TF-828: upstream `@differentiable` attribute type-checking.
1 parent 29268e2 commit 732e0f4

File tree

2 files changed

+198
-0
lines changed

2 files changed

+198
-0
lines changed

include/swift/AST/Types.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3190,6 +3190,76 @@ class AnyFunctionType : public TypeBase {
31903190
return getExtInfo().getRepresentation();
31913191
}
31923192

3193+
/// Returns the derivative function type for the given parameter indices,
3194+
/// result index, derivative function kind, derivative function generic
3195+
/// signature (optional), and other auxiliary parameters.
3196+
///
3197+
/// Preconditions:
3198+
/// - Parameters corresponding to parameter indices must conform to
3199+
///   `Differentiable`.
3200+
/// - The result corresponding to the result index must conform to
3201+
///   `Differentiable`.
/// - The parameter indices must not be empty.
/// - The result index must be 0 unless the original result is a tuple type,
///   in which case it selects the tuple element to differentiate.
3202+
///
3203+
/// Typing rules, given:
3204+
/// - Original function type. Three cases:
3205+
///   - Top-level function: `(T0, T1, ...) -> R`
3206+
///   - Static method: `(Self.Type) -> (T0, T1, ...) -> R`
3207+
///   - Instance method: `(Self) -> (T0, T1, ...) -> R`
3208+
///
3209+
/// Terminology:
3210+
/// - The derivative of a value of a `Differentiable`-conforming type has
3211+
///   that type's `TangentVector` associated type. `TangentVector` is
3212+
///   abbreviated as `Tan` below.
3213+
/// - "wrt" parameters refers to differentiability parameters, identified by
3214+
///   the parameter indices.
3215+
/// - "wrt" result refers to the result identified by the result index.
3216+
///
3217+
/// JVP derivative type:
3218+
/// - Takes original parameters.
3219+
/// - Returns original result, followed by a differential function, which
3220+
///   takes "wrt" parameter derivatives and returns a "wrt" result derivative.
3221+
///
3222+
/// \verbatim
3223+
/// (T0, T1, ...) -> (R,         (T0.Tan, T1.Tan, ...) -> R.Tan)
3224+
///                   ^          ^~~~~~~~~~~~~~~~~~~~    ^~~~~
3225+
///            original result | derivatives wrt params | derivative wrt result
3226+
///
3227+
/// (Self) -> (T0, ...) -> (R,     (Self.Tan, T0.Tan, ...) -> R.Tan)
3228+
///                         ^      ^~~~~~~~~~~~~~~~~~~~~~    ^~~~~
3229+
///              original result | deriv. wrt params | deriv. wrt result
3230+
/// \endverbatim
3231+
///
3232+
/// VJP derivative type:
3233+
/// - Takes original parameters.
3234+
/// - Returns original result, followed by a pullback function, which
3235+
///   takes a "wrt" result derivative and returns "wrt" parameter derivatives.
3236+
///
3237+
/// \verbatim
3238+
/// (T0, T1, ...) -> (R,                      (R.Tan) -> (T0.Tan, T1.Tan, ...))
3239+
///                   ^                        ^~~~~     ^~~~~~~~~~~~~~~~~~~~
3240+
///            original result | derivative wrt result | derivatives wrt params
3241+
///
3242+
/// (Self) -> (T0, ...) -> (R,              (R.Tan) -> (Self.Tan, T0.Tan, ...))
3243+
///                         ^                ^~~~~     ^~~~~~~~~~~~~~~~~~~~~~
3244+
///              original result | deriv. wrt result | deriv. wrt params
3245+
/// \endverbatim
3246+
///
3247+
/// By default, if the original type has a `self` parameter list and parameter
3248+
/// indices include `self`, the computed derivative function type will return
3249+
/// a linear map taking/returning self's tangent *last* instead of first, for
3250+
/// consistency with SIL.
3251+
///
3252+
/// If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
3253+
/// first. `makeSelfParamFirst` should be true when working with user-facing
3254+
/// derivative function types, e.g. when type-checking `@differentiable` and
3255+
/// `@derivative` attributes.
/// NOTE(review): the method diagrams above list `Self.Tan` first, which
/// appears to be the `makeSelfParamFirst == true` ordering -- confirm.
///
/// If `derivativeGenericSignature` is null, the generic signature of the
/// original function type (if any) is used for the derivative function type.
3256+
AnyFunctionType *getAutoDiffDerivativeFunctionType(
3257+
IndexSubset *parameterIndices, unsigned resultIndex,
3258+
AutoDiffDerivativeFunctionKind kind,
3259+
LookupConformanceFn lookupConformance,
3260+
GenericSignature derivativeGenericSignature = GenericSignature(),
3261+
bool makeSelfParamFirst = false);
3262+
31933263
/// True if the parameter declaration it is attached to is guaranteed
31943264
/// to not persist the closure for longer than the duration of the call.
31953265
bool isNoEscape() const {

lib/AST/Type.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4910,6 +4910,134 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
49104910
return cache(None);
49114911
}
49124912

4913+
// Creates an `AnyFunctionType` from the given parameters, result type,
4914+
// generic signature, and `ExtInfo`.
4915+
static AnyFunctionType *
4916+
makeFunctionType(ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
4917+
GenericSignature genericSignature,
4918+
AnyFunctionType::ExtInfo extInfo) {
4919+
if (genericSignature)
4920+
return GenericFunctionType::get(genericSignature, parameters, resultType,
4921+
extInfo);
4922+
return FunctionType::get(parameters, resultType, extInfo);
4923+
}
4924+
4925+
AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
4926+
IndexSubset *parameterIndices, unsigned resultIndex,
4927+
AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance,
4928+
GenericSignature derivativeGenSig, bool makeSelfParamFirst) {
4929+
assert(!parameterIndices->isEmpty() &&
4930+
"Expected at least one differentiability parameter");
4931+
auto &ctx = getASTContext();
4932+
4933+
// If `derivativeGenSig` is not defined, use the current function's type
4934+
// generic signature.
4935+
if (!derivativeGenSig)
4936+
derivativeGenSig = getOptGenericSignature();
4937+
4938+
// Get differentiability parameter types.
4939+
SmallVector<Type, 8> diffParamTypes;
4940+
autodiff::getSubsetParameterTypes(parameterIndices, this, diffParamTypes,
4941+
/*reverseCurryLevels*/ !makeSelfParamFirst);
4942+
4943+
// Unwrap curry levels. At most, two parameter lists are necessary, for
4944+
// curried method types with a `(Self)` parameter list.
4945+
// TODO(TF-874): Simplify curry level logic.
4946+
SmallVector<AnyFunctionType *, 2> curryLevels;
4947+
auto *currentLevel = castTo<AnyFunctionType>();
4948+
for (unsigned i : range(2)) {
4949+
(void)i;
4950+
if (currentLevel == nullptr)
4951+
break;
4952+
curryLevels.push_back(currentLevel);
4953+
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
4954+
}
4955+
4956+
Type originalResult = curryLevels.back()->getResult();
4957+
4958+
// Build the result linear map function type.
4959+
Type linearMapType;
4960+
switch (kind) {
4961+
case AutoDiffDerivativeFunctionKind::JVP: {
4962+
// Differential function type, a result of the JVP:
4963+
// `LinearMapType = (T.TangentVector, ...) -> (R.TangentVector)`
4964+
SmallVector<AnyFunctionType::Param, 8> differentialParams;
4965+
for (auto diffParamType : diffParamTypes)
4966+
differentialParams.push_back(AnyFunctionType::Param(
4967+
diffParamType->getAutoDiffTangentSpace(lookupConformance)
4968+
->getType()));
4969+
SmallVector<TupleTypeElt, 8> differentialResults;
4970+
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
4971+
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
4972+
differentialResults.push_back(
4973+
resultTupleEltType->getAutoDiffTangentSpace(lookupConformance)
4974+
->getType());
4975+
} else {
4976+
assert(resultIndex == 0 && "resultIndex out of bounds");
4977+
differentialResults.push_back(
4978+
originalResult->getAutoDiffTangentSpace(lookupConformance)
4979+
->getType());
4980+
}
4981+
Type differentialResult = differentialResults.size() > 1
4982+
? TupleType::get(differentialResults, ctx)
4983+
: differentialResults[0].getType();
4984+
linearMapType = FunctionType::get(differentialParams, differentialResult);
4985+
break;
4986+
}
4987+
case AutoDiffDerivativeFunctionKind::VJP: {
4988+
// Pullback function type, a result of the VJP:
4989+
// `LinearMapType = (R.TangentVector) -> (T.TangentVector, ...)`
4990+
SmallVector<AnyFunctionType::Param, 8> pullbackParams;
4991+
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
4992+
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
4993+
pullbackParams.push_back(AnyFunctionType::Param(
4994+
resultTupleEltType->getAutoDiffTangentSpace(lookupConformance)
4995+
->getType()));
4996+
} else {
4997+
assert(resultIndex == 0 &&
4998+
"Expected result index 0 for non-tuple result");
4999+
pullbackParams.push_back(AnyFunctionType::Param(
5000+
originalResult->getAutoDiffTangentSpace(lookupConformance)
5001+
->getType()));
5002+
}
5003+
SmallVector<TupleTypeElt, 8> pullbackResults;
5004+
for (auto diffParamType : diffParamTypes)
5005+
pullbackResults.push_back(
5006+
diffParamType->getAutoDiffTangentSpace(lookupConformance)->getType());
5007+
Type pullbackResult = pullbackResults.size() > 1
5008+
? TupleType::get(pullbackResults, ctx)
5009+
: pullbackResults[0].getType();
5010+
linearMapType = FunctionType::get(pullbackParams, pullbackResult);
5011+
break;
5012+
}
5013+
}
5014+
assert(linearMapType && "Expected linear map type");
5015+
5016+
// Build the full derivative function type: `(T...) -> (R, LinearMapType)`.
5017+
SmallVector<TupleTypeElt, 2> retElts;
5018+
retElts.push_back(originalResult);
5019+
retElts.push_back(linearMapType);
5020+
auto retTy = TupleType::get(retElts, ctx);
5021+
auto *derivativeFunctionType =
5022+
makeFunctionType(curryLevels.back()->getParams(), retTy,
5023+
curryLevels.size() == 1 ? derivativeGenSig : nullptr,
5024+
curryLevels.back()->getExtInfo());
5025+
5026+
// Wrap the derivative function type in additional curry levels.
5027+
auto curryLevelsWithoutLast =
5028+
ArrayRef<AnyFunctionType *>(curryLevels).drop_back(1);
5029+
for (auto pair : enumerate(llvm::reverse(curryLevelsWithoutLast))) {
5030+
unsigned i = pair.index();
5031+
auto *curryLevel = pair.value();
5032+
derivativeFunctionType = makeFunctionType(
5033+
curryLevel->getParams(), derivativeFunctionType,
5034+
i == curryLevelsWithoutLast.size() - 1 ? derivativeGenSig : nullptr,
5035+
curryLevel->getExtInfo());
5036+
}
5037+
5038+
return derivativeFunctionType;
5039+
}
5040+
49135041
CanSILFunctionType
49145042
SILFunctionType::withSubstitutions(SubstitutionMap subs) const {
49155043
return SILFunctionType::get(getSubstGenericSignature(),

0 commit comments

Comments
 (0)