@@ -1048,6 +1048,169 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
1048
1048
return builder.build (Id);
1049
1049
}
1050
1050
1051
+ static ValueDecl *getDifferentiableFunctionConstructor (
1052
+ ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1053
+ assert (arity >= 1 );
1054
+ unsigned numGenericParams = 1 + arity;
1055
+ BuiltinFunctionBuilder builder (Context, numGenericParams);
1056
+ // Get the `Differentiable` and `AdditiveArithmetic` protocols.
1057
+ auto *diffableProto =
1058
+ Context.getProtocol (KnownProtocolKind::Differentiable);
1059
+ auto *tangentVectorDecl =
1060
+ diffableProto->getAssociatedType (Context.Id_TangentVector );
1061
+ assert (tangentVectorDecl);
1062
+ // Create type parameters and add conformance constraints.
1063
+ auto origResultGen = makeGenericParam (arity);
1064
+ builder.addConformanceRequirement (origResultGen, diffableProto);
1065
+ SmallVector<decltype (origResultGen), 2 > fnArgGens;
1066
+ for (auto i : range (arity)) {
1067
+ auto T = makeGenericParam (i);
1068
+ builder.addConformanceRequirement (T, diffableProto);
1069
+ fnArgGens.push_back (T);
1070
+ }
1071
+
1072
+ BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1073
+ [=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1074
+ SmallVector<FunctionType::Param, 2 > params;
1075
+ for (auto ¶mGen : fnArgGens)
1076
+ params.push_back (FunctionType::Param (paramGen.build (builder)));
1077
+ return FunctionType::get (params, origResultGen.build (builder))
1078
+ ->withExtInfo (
1079
+ FunctionType::ExtInfo (FunctionTypeRepresentation::Swift, throws));
1080
+ }
1081
+ };
1082
+
1083
+ BuiltinFunctionBuilder::LambdaGenerator jvpGen {
1084
+ [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1085
+ SmallVector<FunctionType::Param, 2 > params;
1086
+ for (auto ¶mGen : fnArgGens)
1087
+ params.push_back (FunctionType::Param (paramGen.build (builder)));
1088
+ auto origResultType = origResultGen.build (builder);
1089
+ SmallVector<FunctionType::Param, 2 > differentialParams;
1090
+ for (auto ¶m : params) {
1091
+ auto tanType = DependentMemberType::get (
1092
+ param.getPlainType (), tangentVectorDecl);
1093
+ differentialParams.push_back (FunctionType::Param (tanType));
1094
+ }
1095
+ auto differentialResultType = DependentMemberType::get (
1096
+ origResultType, tangentVectorDecl);
1097
+ auto differentialType =
1098
+ FunctionType::get ({differentialParams}, differentialResultType);
1099
+ auto jvpResultType = TupleType::get (
1100
+ {TupleTypeElt (origResultType, Context.Id_value ),
1101
+ TupleTypeElt (differentialType, Context.Id_differential )}, Context);
1102
+ return FunctionType::get (params, jvpResultType)
1103
+ ->withExtInfo (
1104
+ FunctionType::ExtInfo (FunctionTypeRepresentation::Swift, throws));
1105
+ }
1106
+ };
1107
+
1108
+ BuiltinFunctionBuilder::LambdaGenerator vjpGen {
1109
+ [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1110
+ SmallVector<FunctionType::Param, 2 > params;
1111
+ for (auto ¶mGen : fnArgGens)
1112
+ params.push_back (FunctionType::Param (paramGen.build (builder)));
1113
+ auto origResultType = origResultGen.build (builder);
1114
+ SmallVector<TupleTypeElt, 2 > pullbackResultTupleElts;
1115
+ for (auto ¶m : params) {
1116
+ auto tanType = DependentMemberType::get (
1117
+ param.getPlainType (), tangentVectorDecl);
1118
+ pullbackResultTupleElts.push_back (TupleTypeElt (tanType));
1119
+ }
1120
+ auto pullbackParam = FunctionType::Param (
1121
+ DependentMemberType::get (origResultType, tangentVectorDecl));
1122
+ auto pullbackType = FunctionType::get (
1123
+ {pullbackParam},
1124
+ pullbackResultTupleElts.size () == 1
1125
+ ? pullbackResultTupleElts.front ().getType ()
1126
+ : TupleType::get (pullbackResultTupleElts, Context));
1127
+ auto vjpResultType = TupleType::get (
1128
+ {TupleTypeElt (origResultType, Context.Id_value ),
1129
+ TupleTypeElt (pullbackType, Context.Id_pullback )}, Context);
1130
+ return FunctionType::get (params, vjpResultType)
1131
+ ->withExtInfo (
1132
+ FunctionType::ExtInfo (FunctionTypeRepresentation::Swift, throws));
1133
+ }
1134
+ };
1135
+
1136
+ BuiltinFunctionBuilder::LambdaGenerator resultGen {
1137
+ [&](BuiltinFunctionBuilder &builder) -> Type {
1138
+ auto origFnType = origFnGen.build (builder)->castTo <FunctionType>();
1139
+ return origFnType->withExtInfo (
1140
+ origFnType->getExtInfo ()
1141
+ .withDifferentiabilityKind (DifferentiabilityKind::Normal));
1142
+ }
1143
+ };
1144
+
1145
+ builder.addParameter (origFnGen, ValueOwnership::Owned);
1146
+ builder.addParameter (jvpGen, ValueOwnership::Owned);
1147
+ builder.addParameter (vjpGen, ValueOwnership::Owned);
1148
+ builder.setResult (resultGen);
1149
+ return builder.build (Id);
1150
+ }
1151
+
1152
+ static ValueDecl *getLinearFunctionConstructor (
1153
+ ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1154
+ assert (arity >= 1 );
1155
+ unsigned numGenericParams = 1 + arity;
1156
+ BuiltinFunctionBuilder builder (Context, numGenericParams);
1157
+ // Get the `Differentiable` and `AdditiveArithmetic` protocols.
1158
+ auto *diffableProto =
1159
+ Context.getProtocol (KnownProtocolKind::Differentiable);
1160
+ auto *addArithProto =
1161
+ Context.getProtocol (KnownProtocolKind::AdditiveArithmetic);
1162
+ // Create type parameters and add conformance constraints.
1163
+ auto origResultGen = makeGenericParam (arity);
1164
+ builder.addConformanceRequirement (origResultGen, diffableProto);
1165
+ builder.addConformanceRequirement (origResultGen, addArithProto);
1166
+ SmallVector<decltype (origResultGen), 2 > fnArgGens;
1167
+ for (auto i : range (arity)) {
1168
+ auto T = makeGenericParam (i);
1169
+ builder.addConformanceRequirement (T, diffableProto);
1170
+ builder.addConformanceRequirement (T, addArithProto);
1171
+ fnArgGens.push_back (T);
1172
+ }
1173
+
1174
+ BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1175
+ [=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1176
+ SmallVector<FunctionType::Param, 2 > params;
1177
+ for (auto ¶mGen : fnArgGens)
1178
+ params.push_back (FunctionType::Param (paramGen.build (builder)));
1179
+ return FunctionType::get (params, origResultGen.build (builder))
1180
+ ->withExtInfo (
1181
+ FunctionType::ExtInfo (FunctionTypeRepresentation::Swift, throws));
1182
+ }
1183
+ };
1184
+
1185
+ BuiltinFunctionBuilder::LambdaGenerator transposeFnGen {
1186
+ [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1187
+ auto origResultType = origResultGen.build (builder);
1188
+ SmallVector<TupleTypeElt, 2 > resultTupleElts;
1189
+ for (auto ¶mGen : fnArgGens)
1190
+ resultTupleElts.push_back (paramGen.build (builder));
1191
+ return FunctionType::get (
1192
+ {FunctionType::Param (origResultType)},
1193
+ resultTupleElts.size () == 1
1194
+ ? resultTupleElts.front ().getType ()
1195
+ : TupleType::get (resultTupleElts, Context));
1196
+ }
1197
+ };
1198
+
1199
+ BuiltinFunctionBuilder::LambdaGenerator resultGen {
1200
+ [&](BuiltinFunctionBuilder &builder) -> Type {
1201
+ auto origFnType = origFnGen.build (builder)->castTo <FunctionType>();
1202
+ return origFnType->withExtInfo (
1203
+ origFnType->getExtInfo ()
1204
+ .withDifferentiabilityKind (DifferentiabilityKind::Linear));
1205
+ }
1206
+ };
1207
+
1208
+ builder.addParameter (origFnGen, ValueOwnership::Owned);
1209
+ builder.addParameter (transposeFnGen, ValueOwnership::Owned);
1210
+ builder.setResult (resultGen);
1211
+ return builder.build (Id);
1212
+ }
1213
+
1051
1214
static ValueDecl *getGlobalStringTablePointer (ASTContext &Context,
1052
1215
Identifier Id) {
1053
1216
// String -> Builtin.RawPointer
@@ -1839,6 +2002,22 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
1839
2002
return getAutoDiffApplyDerivativeFunction (Context, Id, kind, arity,
1840
2003
rethrows);
1841
2004
}
2005
+ if (OperationName.startswith (" differentiableFunction_" )) {
2006
+ unsigned arity;
2007
+ bool throws;
2008
+ if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig (
2009
+ OperationName, arity, throws))
2010
+ return nullptr ;
2011
+ return getDifferentiableFunctionConstructor (Context, Id, arity, throws);
2012
+ }
2013
+ if (OperationName.startswith (" linearFunction_" )) {
2014
+ unsigned arity;
2015
+ bool throws;
2016
+ if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig (
2017
+ OperationName, arity, throws))
2018
+ return nullptr ;
2019
+ return getLinearFunctionConstructor (Context, Id, arity, throws);
2020
+ }
1842
2021
auto BV = llvm::StringSwitch<BuiltinValueKind>(OperationName)
1843
2022
#define BUILTIN (id, name, Attrs ) .Case(name, BuiltinValueKind::id)
1844
2023
#include " swift/AST/Builtins.def"
@@ -2110,6 +2289,8 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
2110
2289
2111
2290
// SWIFT_ENABLE_TENSORFLOW
2112
2291
case BuiltinValueKind::AutoDiffApply:
2292
+ case BuiltinValueKind::DifferentiableFunction:
2293
+ case BuiltinValueKind::LinearFunction:
2113
2294
llvm_unreachable (" Handled above" );
2114
2295
2115
2296
case BuiltinValueKind::OnFastPath:
0 commit comments