@@ -985,7 +985,7 @@ static ValueDecl *getGetObjCTypeEncodingOperation(ASTContext &Context,
985
985
// SWIFT_ENABLE_TENSORFLOW
986
986
static ValueDecl *getAutoDiffApplyDerivativeFunction (
987
987
ASTContext &Context, Identifier Id, AutoDiffDerivativeFunctionKind kind,
988
- unsigned arity, bool rethrows ) {
988
+ unsigned arity, bool throws ) {
989
989
assert (arity >= 1 );
990
990
// JVP:
991
991
// <...T...(arity), R> (@differentiable (...T) throws -> R, ...T)
@@ -1000,54 +1000,114 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
1000
1000
// Create type parameters and add conformance constraints.
1001
1001
auto fnResultGen = makeGenericParam (arity);
1002
1002
builder.addConformanceRequirement (fnResultGen, diffableProto);
1003
- SmallVector<decltype (fnResultGen), 2 > fnArgGens ;
1003
+ SmallVector<decltype (fnResultGen), 2 > fnParamGens ;
1004
1004
for (auto i : range (arity)) {
1005
1005
auto T = makeGenericParam (i);
1006
1006
builder.addConformanceRequirement (T, diffableProto);
1007
- fnArgGens .push_back (T);
1007
+ fnParamGens .push_back (T);
1008
1008
}
1009
- // Generator for the first argument, i.e. the @differentiable function.
1009
+ // Generator for the first argument, i.e. the ` @differentiable` function.
1010
1010
BuiltinFunctionBuilder::LambdaGenerator firstArgGen {
1011
1011
// Generator for the function type at the argument position, i.e. the
1012
1012
// function being differentiated.
1013
- [=, &fnArgGens ](BuiltinFunctionBuilder &builder) -> Type {
1013
+ [=, &fnParamGens ](BuiltinFunctionBuilder &builder) -> Type {
1014
1014
FunctionType::ExtInfo ext;
1015
1015
auto extInfo = FunctionType::ExtInfo ()
1016
1016
.withDifferentiabilityKind (DifferentiabilityKind::Normal)
1017
- .withNoEscape ().withThrows (rethrows );
1017
+ .withNoEscape ().withThrows (throws );
1018
1018
SmallVector<FunctionType::Param, 2 > params;
1019
- for (auto ¶mGen : fnArgGens )
1019
+ for (auto ¶mGen : fnParamGens )
1020
1020
params.push_back (FunctionType::Param (paramGen.build (builder)));
1021
1021
auto innerFunction = FunctionType::get (params,
1022
1022
fnResultGen.build (builder));
1023
1023
return innerFunction->withExtInfo (extInfo);
1024
1024
}
1025
1025
};
1026
1026
// Eagerly build the type of the first arg, then use that to compute the type
1027
- // of the derivative function type .
1028
- auto *origFnTy =
1027
+ // of the result .
1028
+ auto *diffFnType =
1029
1029
firstArgGen.build (builder)->castTo <AnyFunctionType>();
1030
- origFnTy = origFnTy ->getWithoutDifferentiability ()->withExtInfo (
1031
- origFnTy ->getExtInfo ().withNoEscape (false ));
1030
+ diffFnType = diffFnType ->getWithoutDifferentiability ()->withExtInfo (
1031
+ diffFnType ->getExtInfo ().withNoEscape (false ));
1032
1032
auto *paramIndices = IndexSubset::get (
1033
- Context, SmallBitVector (origFnTy ->getNumParams (), true ));
1033
+ Context, SmallBitVector (diffFnType ->getNumParams (), true ));
1034
1034
// Generator for the resultant function type, i.e. the AD derivative function.
1035
1035
BuiltinFunctionBuilder::LambdaGenerator resultGen{
1036
1036
[=, &Context](BuiltinFunctionBuilder &builder) -> Type {
1037
- auto derivativeFnTy = origFnTy ->getAutoDiffDerivativeFunctionType (
1037
+ auto derivativeFnTy = diffFnType ->getAutoDiffDerivativeFunctionType (
1038
1038
paramIndices, /* resultIndex*/ 0 , kind,
1039
1039
LookUpConformanceInModule (Context.TheBuiltinModule ));
1040
1040
return derivativeFnTy->getResult ();
1041
1041
}};
1042
1042
builder.addParameter (firstArgGen);
1043
- for (auto argGen : fnArgGens )
1043
+ for (auto argGen : fnParamGens )
1044
1044
builder.addParameter (argGen);
1045
- if (rethrows )
1045
+ if (throws )
1046
1046
builder.setRethrows ();
1047
1047
builder.setResult (resultGen);
1048
1048
return builder.build (Id);
1049
1049
}
1050
1050
1051
+ static ValueDecl *getAutoDiffApplyTransposeFunction (
1052
+ ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1053
+ assert (arity >= 1 );
1054
+ // <...T...(arity), R>
1055
+ // (@differentiable (...T) throws -> R, ...R.TangentVector)
1056
+ // rethrows -> (...T.TangentVector)
1057
+ unsigned numGenericParams = 1 + arity;
1058
+ BuiltinFunctionBuilder builder (Context, numGenericParams);
1059
+ auto *diffableProto = Context.getProtocol (KnownProtocolKind::Differentiable);
1060
+ auto *addArithProto =
1061
+ Context.getProtocol (KnownProtocolKind::AdditiveArithmetic);
1062
+ // Create type parameters and add conformance constraints.
1063
+ auto linearFnResultGen = makeGenericParam (arity);
1064
+ builder.addConformanceRequirement (linearFnResultGen, diffableProto);
1065
+ builder.addConformanceRequirement (linearFnResultGen, addArithProto);
1066
+ SmallVector<decltype (linearFnResultGen), 2 > linearFnParamGens;
1067
+ for (auto i : range (arity)) {
1068
+ auto T = makeGenericParam (i);
1069
+ builder.addConformanceRequirement (T, diffableProto);
1070
+ builder.addConformanceRequirement (T, addArithProto);
1071
+ linearFnParamGens.push_back (T);
1072
+ }
1073
+ // Generator for the first argument, i.e. the `@differentiable(linear)`
1074
+ // function.
1075
+ BuiltinFunctionBuilder::LambdaGenerator firstArgGen {
1076
+ // Generator for the function type at the argument position, i.e. the
1077
+ // function being differentiated.
1078
+ [=, &linearFnParamGens](BuiltinFunctionBuilder &builder) -> Type {
1079
+ FunctionType::ExtInfo ext;
1080
+ auto extInfo = FunctionType::ExtInfo ()
1081
+ .withDifferentiabilityKind (DifferentiabilityKind::Linear)
1082
+ .withNoEscape ().withThrows (throws);
1083
+ SmallVector<FunctionType::Param, 2 > params;
1084
+ for (auto ¶mGen : linearFnParamGens)
1085
+ params.push_back (FunctionType::Param (paramGen.build (builder)));
1086
+ auto innerFunction = FunctionType::get (params,
1087
+ linearFnResultGen.build (builder));
1088
+ return innerFunction->withExtInfo (extInfo);
1089
+ }
1090
+ };
1091
+ builder.addParameter (firstArgGen);
1092
+ builder.addParameter (linearFnResultGen);
1093
+ if (throws)
1094
+ builder.setRethrows ();
1095
+ if (arity == 1 )
1096
+ builder.setResult (linearFnParamGens.front ());
1097
+ else {
1098
+ BuiltinFunctionBuilder::LambdaGenerator tupleResultGen {
1099
+ [&](BuiltinFunctionBuilder &builder) -> Type {
1100
+ SmallVector<TupleTypeElt, 2 > tupleElts;
1101
+ for (auto linearFnParamGen : linearFnParamGens)
1102
+ tupleElts.push_back (linearFnParamGen.build (builder));
1103
+ return TupleType::get (tupleElts, Context);
1104
+ }
1105
+ };
1106
+ builder.setResult (tupleResultGen);
1107
+ }
1108
+ return builder.build (Id);
1109
+ }
1110
+
1051
1111
static ValueDecl *getDifferentiableFunctionConstructor (
1052
1112
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1053
1113
assert (arity >= 1 );
@@ -1992,15 +2052,23 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
1992
2052
return getAllocWithTailElemsOperation (Context, Id, NumTailTypes);
1993
2053
}
1994
2054
// SWIFT_ENABLE_TENSORFLOW
1995
- if (OperationName.startswith (" autodiffApply_ " )) {
2055
+ if (OperationName.startswith (" applyDerivative_ " )) {
1996
2056
AutoDiffDerivativeFunctionKind kind;
1997
2057
unsigned arity;
1998
- bool rethrows ;
1999
- if (!autodiff::getBuiltinAutoDiffApplyConfig (OperationName, kind, arity,
2000
- rethrows ))
2058
+ bool throws ;
2059
+ if (!autodiff::getBuiltinApplyDerivativeConfig (
2060
+ OperationName, kind, arity, throws ))
2001
2061
return nullptr ;
2002
2062
return getAutoDiffApplyDerivativeFunction (Context, Id, kind, arity,
2003
- rethrows);
2063
+ throws);
2064
+ }
2065
+ if (OperationName.startswith (" applyTranspose_" )) {
2066
+ unsigned arity;
2067
+ bool throws;
2068
+ if (!autodiff::getBuiltinApplyTransposeConfig (
2069
+ OperationName, arity, throws))
2070
+ return nullptr ;
2071
+ return getAutoDiffApplyTransposeFunction (Context, Id, arity, throws);
2004
2072
}
2005
2073
if (OperationName.startswith (" differentiableFunction_" )) {
2006
2074
unsigned arity;
@@ -2288,7 +2356,8 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
2288
2356
return getUnsafeGuaranteedEnd (Context, Id);
2289
2357
2290
2358
// SWIFT_ENABLE_TENSORFLOW
2291
- case BuiltinValueKind::AutoDiffApply:
2359
+ case BuiltinValueKind::ApplyDerivative:
2360
+ case BuiltinValueKind::ApplyTranspose:
2292
2361
case BuiltinValueKind::DifferentiableFunction:
2293
2362
case BuiltinValueKind::LinearFunction:
2294
2363
llvm_unreachable (" Handled above" );
0 commit comments