Skip to content

Commit e0789da

Browse files
authored
Add builtin functions that construct a @differentiable or @differentiable(linear) function from component functions. (#28467)
Add builtin functions that construct a `@differentiable` or `@differentiable(linear)` function from component functions. * `Builtin.differentiableFunction_*` Takes an original function, a JVP function and a VJP function and returns a `@differentiable` function. Pseudo-declaration: ```swift func differentiableFunction_arity{arity}[_throws]?{throws}<T...{arity}, U>( _ original: __owned @escaping (T...{arity}) {throws}? -> U, _ jvp: __owned @escaping (T...{arity}) {throws}? -> (value: U, differential: (T.TangentVector...{arity}) -> U.TangentVector), _ vjp: __owned @escaping (T...{arity}) {throws}? -> (value: U, pullback: (U.TangentVector) -> (T.TangentVector...{arity})) ) -> @differentiable (T...{arity}) {throws}? -> U where T...{arity} : Differentiable, U : Differentiable ``` * `Builtin.linearFunction_*` Takes an original function and a transpose function and returns a `@differentiable(linear)` function. Pseudo-declaration: ```swift func linearFunction_arity{arity}[_throws]?{throws}<T...{arity}, U>( _ original: __owned @escaping (T...{arity}) {throws}? -> U, _ transpose: __owned @escaping (U.TangentVector) {throws}? -> (T.TangentVector...{arity}) ) -> @differentiable(linear) (T...{arity}) {throws}? -> U where T...{arity} : Differentiable & AdditiveArithmetic, U : Differentiable & AdditiveArithmetic ``` These builtins will be used to write unit tests for `@differentiable` and `@differentiable(linear)` function types that do not necessarily depend on the differentiation transform. TODO: - SR-11848: For robustness, we need SIL FileCheck tests for all AD builtins. These have not been added for `Builtin.autodiffApply*`, so I'm leaving this as a future task. - SR-11847: Update `differentiableFunction(from:)` to use `Builtin.differentiableFunction*` in its implementation. - SR-11849: Disallow non-top-level derivative registration. Resolves SR-11846.
1 parent 21a4bc5 commit e0789da

File tree

9 files changed

+361
-14
lines changed

9 files changed

+361
-14
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,19 @@ IndexSubset *getLoweredParameterIndices(IndexSubset *indices,
352352
AnyFunctionType *type);
353353

354354
/// Retrieve config from the function name of a variant of
355-
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
355+
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2`.
356356
/// Returns true if the function name is parsed successfully.
357357
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
358358
AutoDiffDerivativeFunctionKind &kind,
359359
unsigned &arity, bool &rethrows);
360360

361+
/// Retrieve config from the function name of a variant of
362+
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
363+
/// `Builtin.differentiableFunction_arity1_throws`.
364+
/// Returns true if the function name is parsed successfully.
365+
bool getBuiltinDifferentiableOrLinearFunctionConfig(
366+
StringRef operationName, unsigned &arity, bool &throws);
367+
361368
/// Computes the correct linkage for a derivative function given the linkage of
362369
/// the original function. If the original linkage is not external and
363370
/// `isDerivativeFnExported` is true, use the original function's linkage.

include/swift/AST/Builtins.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ BUILTIN_SIL_OPERATION(ProjectTailElems, "projectTailElems", Special)
448448
/// autodiffApply
449449
BUILTIN_SIL_OPERATION(AutoDiffApply, "autodiffApply", Special)
450450

451+
/// differentiableFunction
452+
BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special)
453+
454+
/// linearFunction
455+
BUILTIN_SIL_OPERATION(LinearFunction, "linearFunction", Special)
456+
451457
#undef BUILTIN_SIL_OPERATION
452458

453459
// BUILTIN_RUNTIME_CALL - A call into a runtime function.

include/swift/AST/KnownIdentifiers.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ IDENTIFIER(withKeywordArguments)
130130
IDENTIFIER(wrapped)
131131
IDENTIFIER(wrappedValue)
132132
IDENTIFIER(wrapperValue)
133+
// SWIFT_ENABLE_TENSORFLOW
134+
IDENTIFIER(differential)
135+
IDENTIFIER(pullback)
133136

134137
// SWIFT_ENABLE_TENSORFLOW
135138
IDENTIFIER(TensorFlow)

lib/AST/AutoDiff.cpp

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,22 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
209209
bool autodiff::getBuiltinAutoDiffApplyConfig(
210210
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
211211
unsigned &arity, bool &rethrows) {
212-
if (!operationName.startswith("autodiffApply_"))
212+
constexpr char prefix[] = "autodiffApply";
213+
if (!operationName.startswith(prefix))
213214
return false;
214-
operationName = operationName.drop_front(strlen("autodiffApply_"));
215+
operationName = operationName.drop_front(sizeof(prefix) - 1);
215216
// Parse 'jvp' or 'vjp'.
216-
if (operationName.startswith("jvp"))
217+
constexpr char jvpPrefix[] = "_jvp";
218+
constexpr char vjpPrefix[] = "_vjp";
219+
if (operationName.startswith(jvpPrefix))
217220
kind = AutoDiffDerivativeFunctionKind::JVP;
218-
else if (operationName.startswith("vjp"))
221+
else if (operationName.startswith(vjpPrefix))
219222
kind = AutoDiffDerivativeFunctionKind::VJP;
220-
operationName = operationName.drop_front(3);
223+
operationName = operationName.drop_front(sizeof(jvpPrefix) - 1);
221224
// Parse '_arity'.
222-
if (operationName.startswith("_arity")) {
223-
operationName = operationName.drop_front(strlen("_arity"));
225+
constexpr char arityPrefix[] = "_arity";
226+
if (operationName.startswith(arityPrefix)) {
227+
operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
224228
auto arityStr = operationName.take_while(llvm::isDigit);
225229
operationName = operationName.drop_front(arityStr.size());
226230
auto converted = llvm::to_integer(arityStr, arity);
@@ -230,15 +234,49 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
230234
arity = 1;
231235
}
232236
// Parse '_rethrows'.
233-
if (operationName.startswith("_rethrows")) {
234-
operationName = operationName.drop_front(strlen("_rethrows"));
237+
constexpr char rethrowsPrefix[] = "_rethrows";
238+
if (operationName.startswith(rethrowsPrefix)) {
239+
operationName = operationName.drop_front(sizeof(rethrowsPrefix) - 1);
235240
rethrows = true;
236241
} else {
237242
rethrows = false;
238243
}
239244
return operationName.empty();
240245
}
241246

247+
bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
248+
StringRef operationName, unsigned &arity, bool &throws) {
249+
constexpr char differentiablePrefix[] = "differentiableFunction";
250+
constexpr char linearPrefix[] = "linearFunction";
251+
if (operationName.startswith(differentiablePrefix))
252+
operationName = operationName.drop_front(sizeof(differentiablePrefix) - 1);
253+
else if (operationName.startswith(linearPrefix))
254+
operationName = operationName.drop_front(sizeof(linearPrefix) - 1);
255+
else
256+
return false;
257+
// Parse '_arity'.
258+
constexpr char arityPrefix[] = "_arity";
259+
if (operationName.startswith(arityPrefix)) {
260+
operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
261+
auto arityStr = operationName.take_while(llvm::isDigit);
262+
operationName = operationName.drop_front(arityStr.size());
263+
auto converted = llvm::to_integer(arityStr, arity);
264+
assert(converted); (void)converted;
265+
assert(arity > 0);
266+
} else {
267+
arity = 1;
268+
}
269+
// Parse '_throws'.
270+
constexpr char throwsPrefix[] = "_throws";
271+
if (operationName.startswith(throwsPrefix)) {
272+
operationName = operationName.drop_front(sizeof(throwsPrefix) - 1);
273+
throws = true;
274+
} else {
275+
throws = false;
276+
}
277+
return operationName.empty();
278+
}
279+
242280
SILLinkage autodiff::getAutoDiffDerivativeFunctionLinkage(
243281
SILLinkage originalLinkage, bool isDerivativeFnExported) {
244282
// If the original is defined externally, then the AD pass is just generating

lib/AST/Builtins.cpp

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,169 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
10481048
return builder.build(Id);
10491049
}
10501050

1051+
static ValueDecl *getDifferentiableFunctionConstructor(
1052+
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1053+
assert(arity >= 1);
1054+
unsigned numGenericParams = 1 + arity;
1055+
BuiltinFunctionBuilder builder(Context, numGenericParams);
1056+
// Get the `Differentiable` protocol and its `TangentVector` associated type.
1057+
auto *diffableProto =
1058+
Context.getProtocol(KnownProtocolKind::Differentiable);
1059+
auto *tangentVectorDecl =
1060+
diffableProto->getAssociatedType(Context.Id_TangentVector);
1061+
assert(tangentVectorDecl);
1062+
// Create type parameters and add conformance constraints.
1063+
auto origResultGen = makeGenericParam(arity);
1064+
builder.addConformanceRequirement(origResultGen, diffableProto);
1065+
SmallVector<decltype(origResultGen), 2> fnArgGens;
1066+
for (auto i : range(arity)) {
1067+
auto T = makeGenericParam(i);
1068+
builder.addConformanceRequirement(T, diffableProto);
1069+
fnArgGens.push_back(T);
1070+
}
1071+
1072+
BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1073+
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1074+
SmallVector<FunctionType::Param, 2> params;
1075+
for (auto &paramGen : fnArgGens)
1076+
params.push_back(FunctionType::Param(paramGen.build(builder)));
1077+
return FunctionType::get(params, origResultGen.build(builder))
1078+
->withExtInfo(
1079+
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
1080+
}
1081+
};
1082+
1083+
BuiltinFunctionBuilder::LambdaGenerator jvpGen {
1084+
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1085+
SmallVector<FunctionType::Param, 2> params;
1086+
for (auto &paramGen : fnArgGens)
1087+
params.push_back(FunctionType::Param(paramGen.build(builder)));
1088+
auto origResultType = origResultGen.build(builder);
1089+
SmallVector<FunctionType::Param, 2> differentialParams;
1090+
for (auto &param : params) {
1091+
auto tanType = DependentMemberType::get(
1092+
param.getPlainType(), tangentVectorDecl);
1093+
differentialParams.push_back(FunctionType::Param(tanType));
1094+
}
1095+
auto differentialResultType = DependentMemberType::get(
1096+
origResultType, tangentVectorDecl);
1097+
auto differentialType =
1098+
FunctionType::get({differentialParams}, differentialResultType);
1099+
auto jvpResultType = TupleType::get(
1100+
{TupleTypeElt(origResultType, Context.Id_value),
1101+
TupleTypeElt(differentialType, Context.Id_differential)}, Context);
1102+
return FunctionType::get(params, jvpResultType)
1103+
->withExtInfo(
1104+
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
1105+
}
1106+
};
1107+
1108+
BuiltinFunctionBuilder::LambdaGenerator vjpGen {
1109+
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1110+
SmallVector<FunctionType::Param, 2> params;
1111+
for (auto &paramGen : fnArgGens)
1112+
params.push_back(FunctionType::Param(paramGen.build(builder)));
1113+
auto origResultType = origResultGen.build(builder);
1114+
SmallVector<TupleTypeElt, 2> pullbackResultTupleElts;
1115+
for (auto &param : params) {
1116+
auto tanType = DependentMemberType::get(
1117+
param.getPlainType(), tangentVectorDecl);
1118+
pullbackResultTupleElts.push_back(TupleTypeElt(tanType));
1119+
}
1120+
auto pullbackParam = FunctionType::Param(
1121+
DependentMemberType::get(origResultType, tangentVectorDecl));
1122+
auto pullbackType = FunctionType::get(
1123+
{pullbackParam},
1124+
pullbackResultTupleElts.size() == 1
1125+
? pullbackResultTupleElts.front().getType()
1126+
: TupleType::get(pullbackResultTupleElts, Context));
1127+
auto vjpResultType = TupleType::get(
1128+
{TupleTypeElt(origResultType, Context.Id_value),
1129+
TupleTypeElt(pullbackType, Context.Id_pullback)}, Context);
1130+
return FunctionType::get(params, vjpResultType)
1131+
->withExtInfo(
1132+
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
1133+
}
1134+
};
1135+
1136+
BuiltinFunctionBuilder::LambdaGenerator resultGen {
1137+
[&](BuiltinFunctionBuilder &builder) -> Type {
1138+
auto origFnType = origFnGen.build(builder)->castTo<FunctionType>();
1139+
return origFnType->withExtInfo(
1140+
origFnType->getExtInfo()
1141+
.withDifferentiabilityKind(DifferentiabilityKind::Normal));
1142+
}
1143+
};
1144+
1145+
builder.addParameter(origFnGen, ValueOwnership::Owned);
1146+
builder.addParameter(jvpGen, ValueOwnership::Owned);
1147+
builder.addParameter(vjpGen, ValueOwnership::Owned);
1148+
builder.setResult(resultGen);
1149+
return builder.build(Id);
1150+
}
1151+
1152+
static ValueDecl *getLinearFunctionConstructor(
1153+
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1154+
assert(arity >= 1);
1155+
unsigned numGenericParams = 1 + arity;
1156+
BuiltinFunctionBuilder builder(Context, numGenericParams);
1157+
// Get the `Differentiable` and `AdditiveArithmetic` protocols.
1158+
auto *diffableProto =
1159+
Context.getProtocol(KnownProtocolKind::Differentiable);
1160+
auto *addArithProto =
1161+
Context.getProtocol(KnownProtocolKind::AdditiveArithmetic);
1162+
// Create type parameters and add conformance constraints.
1163+
auto origResultGen = makeGenericParam(arity);
1164+
builder.addConformanceRequirement(origResultGen, diffableProto);
1165+
builder.addConformanceRequirement(origResultGen, addArithProto);
1166+
SmallVector<decltype(origResultGen), 2> fnArgGens;
1167+
for (auto i : range(arity)) {
1168+
auto T = makeGenericParam(i);
1169+
builder.addConformanceRequirement(T, diffableProto);
1170+
builder.addConformanceRequirement(T, addArithProto);
1171+
fnArgGens.push_back(T);
1172+
}
1173+
1174+
BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1175+
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1176+
SmallVector<FunctionType::Param, 2> params;
1177+
for (auto &paramGen : fnArgGens)
1178+
params.push_back(FunctionType::Param(paramGen.build(builder)));
1179+
return FunctionType::get(params, origResultGen.build(builder))
1180+
->withExtInfo(
1181+
FunctionType::ExtInfo(FunctionTypeRepresentation::Swift, throws));
1182+
}
1183+
};
1184+
1185+
BuiltinFunctionBuilder::LambdaGenerator transposeFnGen {
1186+
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1187+
auto origResultType = origResultGen.build(builder);
1188+
SmallVector<TupleTypeElt, 2> resultTupleElts;
1189+
for (auto &paramGen : fnArgGens)
1190+
resultTupleElts.push_back(paramGen.build(builder));
1191+
return FunctionType::get(
1192+
{FunctionType::Param(origResultType)},
1193+
resultTupleElts.size() == 1
1194+
? resultTupleElts.front().getType()
1195+
: TupleType::get(resultTupleElts, Context));
1196+
}
1197+
};
1198+
1199+
BuiltinFunctionBuilder::LambdaGenerator resultGen {
1200+
[&](BuiltinFunctionBuilder &builder) -> Type {
1201+
auto origFnType = origFnGen.build(builder)->castTo<FunctionType>();
1202+
return origFnType->withExtInfo(
1203+
origFnType->getExtInfo()
1204+
.withDifferentiabilityKind(DifferentiabilityKind::Linear));
1205+
}
1206+
};
1207+
1208+
builder.addParameter(origFnGen, ValueOwnership::Owned);
1209+
builder.addParameter(transposeFnGen, ValueOwnership::Owned);
1210+
builder.setResult(resultGen);
1211+
return builder.build(Id);
1212+
}
1213+
10511214
static ValueDecl *getGlobalStringTablePointer(ASTContext &Context,
10521215
Identifier Id) {
10531216
// String -> Builtin.RawPointer
@@ -1839,6 +2002,22 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
18392002
return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity,
18402003
rethrows);
18412004
}
2005+
if (OperationName.startswith("differentiableFunction_")) {
2006+
unsigned arity;
2007+
bool throws;
2008+
if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
2009+
OperationName, arity, throws))
2010+
return nullptr;
2011+
return getDifferentiableFunctionConstructor(Context, Id, arity, throws);
2012+
}
2013+
if (OperationName.startswith("linearFunction_")) {
2014+
unsigned arity;
2015+
bool throws;
2016+
if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
2017+
OperationName, arity, throws))
2018+
return nullptr;
2019+
return getLinearFunctionConstructor(Context, Id, arity, throws);
2020+
}
18422021
auto BV = llvm::StringSwitch<BuiltinValueKind>(OperationName)
18432022
#define BUILTIN(id, name, Attrs) .Case(name, BuiltinValueKind::id)
18442023
#include "swift/AST/Builtins.def"
@@ -2110,6 +2289,8 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
21102289

21112290
// SWIFT_ENABLE_TENSORFLOW
21122291
case BuiltinValueKind::AutoDiffApply:
2292+
case BuiltinValueKind::DifferentiableFunction:
2293+
case BuiltinValueKind::LinearFunction:
21132294
llvm_unreachable("Handled above");
21142295

21152296
case BuiltinValueKind::OnFastPath:

lib/SIL/SILModule.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,10 @@ const BuiltinInfo &SILModule::getBuiltinInfo(Identifier ID) {
312312
// SWIFT_ENABLE_TENSORFLOW
313313
else if (OperationName.startswith("autodiffApply_"))
314314
Info.ID = BuiltinValueKind::AutoDiffApply;
315+
else if (OperationName.startswith("differentiableFunction_"))
316+
Info.ID = BuiltinValueKind::DifferentiableFunction;
317+
else if (OperationName.startswith("linearFunction_"))
318+
Info.ID = BuiltinValueKind::LinearFunction;
315319
else
316320
Info.ID = llvm::StringSwitch<BuiltinValueKind>(OperationName)
317321
#define BUILTIN(id, name, attrs) .Case(name, BuiltinValueKind::id)

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,37 @@ static ManagedValue emitBuiltinAutoDiffApply(SILGenFunction &SGF,
11541154
substitutions, args, C);
11551155
}
11561156

1157+
static ManagedValue emitBuiltinDifferentiableFunction(
1158+
SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
1159+
ArrayRef<ManagedValue> args, SGFContext C) {
1160+
assert(args.size() == 3);
1161+
auto origFn = args.front();
1162+
auto origType = origFn.getType().castTo<SILFunctionType>();
1163+
auto diffFn = SGF.B.createDifferentiableFunction(
1164+
loc,
1165+
IndexSubset::getDefault(
1166+
SGF.getASTContext(), origType->getNumParameters(),
1167+
/*includeAll*/ true),
1168+
origFn.forward(SGF),
1169+
std::make_pair(args[1].forward(SGF), args[2].forward(SGF)));
1170+
return SGF.emitManagedRValueWithCleanup(diffFn);
1171+
}
1172+
1173+
static ManagedValue emitBuiltinLinearFunction(
1174+
SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
1175+
ArrayRef<ManagedValue> args, SGFContext C) {
1176+
assert(args.size() == 2);
1177+
auto origFn = args.front();
1178+
auto origType = origFn.getType().castTo<SILFunctionType>();
1179+
auto linearFn = SGF.B.createLinearFunction(
1180+
loc,
1181+
IndexSubset::getDefault(
1182+
SGF.getASTContext(), origType->getNumParameters(),
1183+
/*includeAll*/ true),
1184+
origFn.forward(SGF), args[1].forward(SGF));
1185+
return SGF.emitManagedRValueWithCleanup(linearFn);
1186+
}
1187+
11571188
/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
11581189
/// ownership convention for named builtins, which is to take (non-trivial)
11591190
/// arguments as Owned, this builtin accepts owned as well as guaranteed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8961,11 +8961,15 @@ void Differentiation::run() {
89618961
for (SILInstruction &i : bb) {
89628962
if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i))
89638963
context.getDifferentiableFunctionInsts().push_back(dfi);
8964+
// Reject uncanonical `linear_function` instructions.
8965+
// FIXME(SR-11850): Add support for linear map transposition.
89648966
else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
8965-
astCtx.Diags.diagnose(
8966-
lfi->getLoc().getSourceLoc(),
8967-
diag::autodiff_conversion_to_linear_function_not_supported);
8968-
errorOccurred = true;
8967+
if (!lfi->hasTransposeFunction()) {
8968+
astCtx.Diags.diagnose(
8969+
lfi->getLoc().getSourceLoc(),
8970+
diag::autodiff_conversion_to_linear_function_not_supported);
8971+
errorOccurred = true;
8972+
}
89698973
}
89708974
}
89718975
}

0 commit comments

Comments
 (0)