@@ -4910,6 +4910,134 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
4910
4910
return cache (None);
4911
4911
}
4912
4912
4913
+ // Creates an `AnyFunctionType` from the given parameters, result type,
4914
+ // generic signature, and `ExtInfo`.
4915
+ static AnyFunctionType *
4916
+ makeFunctionType (ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
4917
+ GenericSignature genericSignature,
4918
+ AnyFunctionType::ExtInfo extInfo) {
4919
+ if (genericSignature)
4920
+ return GenericFunctionType::get (genericSignature, parameters, resultType,
4921
+ extInfo);
4922
+ return FunctionType::get (parameters, resultType, extInfo);
4923
+ }
4924
+
4925
+ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType (
4926
+ IndexSubset *parameterIndices, unsigned resultIndex,
4927
+ AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance,
4928
+ GenericSignature derivativeGenSig, bool makeSelfParamFirst) {
4929
+ assert (!parameterIndices->isEmpty () &&
4930
+ " Expected at least one differentiability parameter" );
4931
+ auto &ctx = getASTContext ();
4932
+
4933
+ // If `derivativeGenSig` is not defined, use the current function's type
4934
+ // generic signature.
4935
+ if (!derivativeGenSig)
4936
+ derivativeGenSig = getOptGenericSignature ();
4937
+
4938
+ // Get differentiability parameter types.
4939
+ SmallVector<Type, 8 > diffParamTypes;
4940
+ autodiff::getSubsetParameterTypes (parameterIndices, this , diffParamTypes,
4941
+ /* reverseCurryLevels*/ !makeSelfParamFirst);
4942
+
4943
+ // Unwrap curry levels. At most, two parameter lists are necessary, for
4944
+ // curried method types with a `(Self)` parameter list.
4945
+ // TODO(TF-874): Simplify curry level logic.
4946
+ SmallVector<AnyFunctionType *, 2 > curryLevels;
4947
+ auto *currentLevel = castTo<AnyFunctionType>();
4948
+ for (unsigned i : range (2 )) {
4949
+ (void )i;
4950
+ if (currentLevel == nullptr )
4951
+ break ;
4952
+ curryLevels.push_back (currentLevel);
4953
+ currentLevel = currentLevel->getResult ()->getAs <AnyFunctionType>();
4954
+ }
4955
+
4956
+ Type originalResult = curryLevels.back ()->getResult ();
4957
+
4958
+ // Build the result linear map function type.
4959
+ Type linearMapType;
4960
+ switch (kind) {
4961
+ case AutoDiffDerivativeFunctionKind::JVP: {
4962
+ // Differential function type, a result of the JVP:
4963
+ // `LinearMapType = (T.TangentVector, ...) -> (R.TangentVector)`
4964
+ SmallVector<AnyFunctionType::Param, 8 > differentialParams;
4965
+ for (auto diffParamType : diffParamTypes)
4966
+ differentialParams.push_back (AnyFunctionType::Param (
4967
+ diffParamType->getAutoDiffTangentSpace (lookupConformance)
4968
+ ->getType ()));
4969
+ SmallVector<TupleTypeElt, 8 > differentialResults;
4970
+ if (auto *resultTuple = originalResult->getAs <TupleType>()) {
4971
+ auto resultTupleEltType = resultTuple->getElementType (resultIndex);
4972
+ differentialResults.push_back (
4973
+ resultTupleEltType->getAutoDiffTangentSpace (lookupConformance)
4974
+ ->getType ());
4975
+ } else {
4976
+ assert (resultIndex == 0 && " resultIndex out of bounds" );
4977
+ differentialResults.push_back (
4978
+ originalResult->getAutoDiffTangentSpace (lookupConformance)
4979
+ ->getType ());
4980
+ }
4981
+ Type differentialResult = differentialResults.size () > 1
4982
+ ? TupleType::get (differentialResults, ctx)
4983
+ : differentialResults[0 ].getType ();
4984
+ linearMapType = FunctionType::get (differentialParams, differentialResult);
4985
+ break ;
4986
+ }
4987
+ case AutoDiffDerivativeFunctionKind::VJP: {
4988
+ // Pullback function type, a result of the VJP:
4989
+ // `LinearMapType = (R.TangentVector) -> (T.TangentVector, ...)`
4990
+ SmallVector<AnyFunctionType::Param, 8 > pullbackParams;
4991
+ if (auto *resultTuple = originalResult->getAs <TupleType>()) {
4992
+ auto resultTupleEltType = resultTuple->getElementType (resultIndex);
4993
+ pullbackParams.push_back (AnyFunctionType::Param (
4994
+ resultTupleEltType->getAutoDiffTangentSpace (lookupConformance)
4995
+ ->getType ()));
4996
+ } else {
4997
+ assert (resultIndex == 0 &&
4998
+ " Expected result index 0 for non-tuple result" );
4999
+ pullbackParams.push_back (AnyFunctionType::Param (
5000
+ originalResult->getAutoDiffTangentSpace (lookupConformance)
5001
+ ->getType ()));
5002
+ }
5003
+ SmallVector<TupleTypeElt, 8 > pullbackResults;
5004
+ for (auto diffParamType : diffParamTypes)
5005
+ pullbackResults.push_back (
5006
+ diffParamType->getAutoDiffTangentSpace (lookupConformance)->getType ());
5007
+ Type pullbackResult = pullbackResults.size () > 1
5008
+ ? TupleType::get (pullbackResults, ctx)
5009
+ : pullbackResults[0 ].getType ();
5010
+ linearMapType = FunctionType::get (pullbackParams, pullbackResult);
5011
+ break ;
5012
+ }
5013
+ }
5014
+ assert (linearMapType && " Expected linear map type" );
5015
+
5016
+ // Build the full derivative function type: `(T...) -> (R, LinearMapType)`.
5017
+ SmallVector<TupleTypeElt, 2 > retElts;
5018
+ retElts.push_back (originalResult);
5019
+ retElts.push_back (linearMapType);
5020
+ auto retTy = TupleType::get (retElts, ctx);
5021
+ auto *derivativeFunctionType =
5022
+ makeFunctionType (curryLevels.back ()->getParams (), retTy,
5023
+ curryLevels.size () == 1 ? derivativeGenSig : nullptr ,
5024
+ curryLevels.back ()->getExtInfo ());
5025
+
5026
+ // Wrap the derivative function type in additional curry levels.
5027
+ auto curryLevelsWithoutLast =
5028
+ ArrayRef<AnyFunctionType *>(curryLevels).drop_back (1 );
5029
+ for (auto pair : enumerate(llvm::reverse (curryLevelsWithoutLast))) {
5030
+ unsigned i = pair.index ();
5031
+ auto *curryLevel = pair.value ();
5032
+ derivativeFunctionType = makeFunctionType (
5033
+ curryLevel->getParams (), derivativeFunctionType,
5034
+ i == curryLevelsWithoutLast.size () - 1 ? derivativeGenSig : nullptr ,
5035
+ curryLevel->getExtInfo ());
5036
+ }
5037
+
5038
+ return derivativeFunctionType;
5039
+ }
5040
+
4913
5041
CanSILFunctionType
4914
5042
SILFunctionType::withSubstitutions (SubstitutionMap subs) const {
4915
5043
return SILFunctionType::get (getSubstGenericSignature (),
0 commit comments