Skip to content

Commit 9853570

Browse files
authored
[AutoDiff] Add 'Builtin.applyTranspose*'. (#28469)
Add support for applying the transpose function in a `@differentiable(linear)` function. The `Builtin.applyTranspose*` builtin takes a `@differentiable(linear)` function and a tangent vector, and returns the result of applying the transpose to the tangent vector. Pseudo-declaration: ```swift func applyTranspose_arity{arity}[_throws?{r}]<T...{arity}, R>( _: @differentiable (T...) {r}?throws -> R, _: R ) {r}?rethrows -> (T...) where T: Differentiable & AdditiveArithmetic, R: Differentiable & AdditiveArithmetic ``` This patch also renames `Builtin.autodiffApply` to `Builtin.applyDerivative` for clarity, and fixes a bug in `LinearDifferentiableSILFunctionTypeLowering` where it expected 3 component values in `rebuildAggregate` instead of 2. Resolves SR-11844 and SR-11851.
1 parent e90816b commit 9853570

12 files changed

+273
-104
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,24 @@ IndexSubset *getLoweredParameterIndices(IndexSubset *indices,
352352
AnyFunctionType *type);
353353

354354
/// Retrieve config from the function name of a variant of
355-
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2`.
355+
/// `Builtin.applyDerivative`, e.g. `Builtin.applyDerivative_jvp_arity2`.
356356
/// Returns true if the function name is parsed successfully.
357-
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
358-
AutoDiffDerivativeFunctionKind &kind,
359-
unsigned &arity, bool &rethrows);
357+
bool getBuiltinApplyDerivativeConfig(
358+
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
359+
unsigned &arity, bool &rethrows);
360+
361+
/// Retrieve config from the function name of a variant of
362+
/// `Builtin.applyTranspose`, e.g. `Builtin.applyTranspose_arity2`.
363+
/// Returns true if the function name is parsed successfully.
364+
bool getBuiltinApplyTransposeConfig(
365+
StringRef operationName, unsigned &arity, bool &rethrows);
366+
367+
/// Retrieve config from the function name of a variant of
368+
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
369+
/// `Builtin.differentiableFunction_arity1_throws`.
370+
/// Returns true if the function name is parsed successfully.
371+
bool getBuiltinDifferentiableOrLinearFunctionConfig(
372+
StringRef operationName, unsigned &arity, bool &throws);
360373

361374
/// Retrieve config from the function name of a variant of
362375
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.

include/swift/AST/Builtins.def

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,11 @@ BUILTIN_SIL_OPERATION(AllocWithTailElems, "allocWithTailElems", Special)
445445
BUILTIN_SIL_OPERATION(ProjectTailElems, "projectTailElems", Special)
446446

447447
// SWIFT_ENABLE_TENSORFLOW
448-
/// autodiffApply
449-
BUILTIN_SIL_OPERATION(AutoDiffApply, "autodiffApply", Special)
448+
/// applyDerivative
449+
BUILTIN_SIL_OPERATION(ApplyDerivative, "applyDerivative", Special)
450+
451+
/// applyTranspose
452+
BUILTIN_SIL_OPERATION(ApplyTranspose, "applyTranspose", Special)
450453

451454
/// differentiableFunction
452455
BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special)

lib/AST/AutoDiff.cpp

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,12 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
206206
}
207207
}
208208

209-
bool autodiff::getBuiltinAutoDiffApplyConfig(
210-
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
211-
unsigned &arity, bool &rethrows) {
212-
constexpr char prefix[] = "autodiffApply";
213-
if (!operationName.startswith(prefix))
214-
return false;
215-
operationName = operationName.drop_front(sizeof(prefix) - 1);
216-
// Parse 'jvp' or 'vjp'.
217-
constexpr char jvpPrefix[] = "_jvp";
218-
constexpr char vjpPrefix[] = "_vjp";
219-
if (operationName.startswith(jvpPrefix))
220-
kind = AutoDiffDerivativeFunctionKind::JVP;
221-
else if (operationName.startswith(vjpPrefix))
222-
kind = AutoDiffDerivativeFunctionKind::VJP;
223-
operationName = operationName.drop_front(sizeof(jvpPrefix) - 1);
209+
// Given the rest of a `Builtin.applyDerivative_{jvp|vjp}` or
210+
// `Builtin.applyTranspose` operation name, attempts to parse the arity and
211+
// throwing-ness from the operation name. Modifies the operation name argument
212+
// in place as substrings get dropped.
213+
static void parseAutoDiffBuiltinCommonConfig(
214+
StringRef &operationName, unsigned &arity, bool &throws) {
224215
// Parse '_arity'.
225216
constexpr char arityPrefix[] = "_arity";
226217
if (operationName.startswith(arityPrefix)) {
@@ -233,14 +224,42 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
233224
} else {
234225
arity = 1;
235226
}
236-
// Parse '_rethrows'.
237-
constexpr char rethrowsPrefix[] = "_rethrows";
238-
if (operationName.startswith(rethrowsPrefix)) {
239-
operationName = operationName.drop_front(sizeof(rethrowsPrefix) - 1);
240-
rethrows = true;
227+
// Parse '_throws'.
228+
constexpr char throwsPrefix[] = "_throws";
229+
if (operationName.startswith(throwsPrefix)) {
230+
operationName = operationName.drop_front(sizeof(throwsPrefix) - 1);
231+
throws = true;
241232
} else {
242-
rethrows = false;
233+
throws = false;
243234
}
235+
}
236+
237+
bool autodiff::getBuiltinApplyDerivativeConfig(
238+
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
239+
unsigned &arity, bool &throws) {
240+
constexpr char prefix[] = "applyDerivative";
241+
if (!operationName.startswith(prefix))
242+
return false;
243+
operationName = operationName.drop_front(sizeof(prefix) - 1);
244+
// Parse 'jvp' or 'vjp'.
245+
constexpr char jvpPrefix[] = "_jvp";
246+
constexpr char vjpPrefix[] = "_vjp";
247+
if (operationName.startswith(jvpPrefix))
248+
kind = AutoDiffDerivativeFunctionKind::JVP;
249+
else if (operationName.startswith(vjpPrefix))
250+
kind = AutoDiffDerivativeFunctionKind::VJP;
251+
operationName = operationName.drop_front(sizeof(jvpPrefix) - 1);
252+
parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
253+
return operationName.empty();
254+
}
255+
256+
bool autodiff::getBuiltinApplyTransposeConfig(
257+
StringRef operationName, unsigned &arity, bool &throws) {
258+
constexpr char prefix[] = "applyTranspose";
259+
if (!operationName.startswith(prefix))
260+
return false;
261+
operationName = operationName.drop_front(sizeof(prefix) - 1);
262+
parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
244263
return operationName.empty();
245264
}
246265

@@ -254,26 +273,7 @@ bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
254273
operationName = operationName.drop_front(sizeof(linearPrefix) - 1);
255274
else
256275
return false;
257-
// Parse '_arity'.
258-
constexpr char arityPrefix[] = "_arity";
259-
if (operationName.startswith(arityPrefix)) {
260-
operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
261-
auto arityStr = operationName.take_while(llvm::isDigit);
262-
operationName = operationName.drop_front(arityStr.size());
263-
auto converted = llvm::to_integer(arityStr, arity);
264-
assert(converted); (void)converted;
265-
assert(arity > 0);
266-
} else {
267-
arity = 1;
268-
}
269-
// Parse '_throws'.
270-
constexpr char throwsPrefix[] = "_throws";
271-
if (operationName.startswith(throwsPrefix)) {
272-
operationName = operationName.drop_front(sizeof(throwsPrefix) - 1);
273-
throws = true;
274-
} else {
275-
throws = false;
276-
}
276+
parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
277277
return operationName.empty();
278278
}
279279

lib/AST/Builtins.cpp

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ static ValueDecl *getGetObjCTypeEncodingOperation(ASTContext &Context,
985985
// SWIFT_ENABLE_TENSORFLOW
986986
static ValueDecl *getAutoDiffApplyDerivativeFunction(
987987
ASTContext &Context, Identifier Id, AutoDiffDerivativeFunctionKind kind,
988-
unsigned arity, bool rethrows) {
988+
unsigned arity, bool throws) {
989989
assert(arity >= 1);
990990
// JVP:
991991
// <...T...(arity), R> (@differentiable (...T) throws -> R, ...T)
@@ -1000,54 +1000,114 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
10001000
// Create type parameters and add conformance constraints.
10011001
auto fnResultGen = makeGenericParam(arity);
10021002
builder.addConformanceRequirement(fnResultGen, diffableProto);
1003-
SmallVector<decltype(fnResultGen), 2> fnArgGens;
1003+
SmallVector<decltype(fnResultGen), 2> fnParamGens;
10041004
for (auto i : range(arity)) {
10051005
auto T = makeGenericParam(i);
10061006
builder.addConformanceRequirement(T, diffableProto);
1007-
fnArgGens.push_back(T);
1007+
fnParamGens.push_back(T);
10081008
}
1009-
// Generator for the first argument, i.e. the @differentiable function.
1009+
// Generator for the first argument, i.e. the `@differentiable` function.
10101010
BuiltinFunctionBuilder::LambdaGenerator firstArgGen {
10111011
// Generator for the function type at the argument position, i.e. the
10121012
// function being differentiated.
1013-
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1013+
[=, &fnParamGens](BuiltinFunctionBuilder &builder) -> Type {
10141014
FunctionType::ExtInfo ext;
10151015
auto extInfo = FunctionType::ExtInfo()
10161016
.withDifferentiabilityKind(DifferentiabilityKind::Normal)
1017-
.withNoEscape().withThrows(rethrows);
1017+
.withNoEscape().withThrows(throws);
10181018
SmallVector<FunctionType::Param, 2> params;
1019-
for (auto &paramGen : fnArgGens)
1019+
for (auto &paramGen : fnParamGens)
10201020
params.push_back(FunctionType::Param(paramGen.build(builder)));
10211021
auto innerFunction = FunctionType::get(params,
10221022
fnResultGen.build(builder));
10231023
return innerFunction->withExtInfo(extInfo);
10241024
}
10251025
};
10261026
// Eagerly build the type of the first arg, then use that to compute the type
1027-
// of the derivative function type.
1028-
auto *origFnTy =
1027+
// of the result.
1028+
auto *diffFnType =
10291029
firstArgGen.build(builder)->castTo<AnyFunctionType>();
1030-
origFnTy = origFnTy->getWithoutDifferentiability()->withExtInfo(
1031-
origFnTy->getExtInfo().withNoEscape(false));
1030+
diffFnType = diffFnType->getWithoutDifferentiability()->withExtInfo(
1031+
diffFnType->getExtInfo().withNoEscape(false));
10321032
auto *paramIndices = IndexSubset::get(
1033-
Context, SmallBitVector(origFnTy->getNumParams(), true));
1033+
Context, SmallBitVector(diffFnType->getNumParams(), true));
10341034
// Generator for the resultant function type, i.e. the AD derivative function.
10351035
BuiltinFunctionBuilder::LambdaGenerator resultGen{
10361036
[=, &Context](BuiltinFunctionBuilder &builder) -> Type {
1037-
auto derivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
1037+
auto derivativeFnTy = diffFnType->getAutoDiffDerivativeFunctionType(
10381038
paramIndices, /*resultIndex*/ 0, kind,
10391039
LookUpConformanceInModule(Context.TheBuiltinModule));
10401040
return derivativeFnTy->getResult();
10411041
}};
10421042
builder.addParameter(firstArgGen);
1043-
for (auto argGen : fnArgGens)
1043+
for (auto argGen : fnParamGens)
10441044
builder.addParameter(argGen);
1045-
if (rethrows)
1045+
if (throws)
10461046
builder.setRethrows();
10471047
builder.setResult(resultGen);
10481048
return builder.build(Id);
10491049
}
10501050

1051+
static ValueDecl *getAutoDiffApplyTransposeFunction(
1052+
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1053+
assert(arity >= 1);
1054+
// <...T...(arity), R>
1055+
// (@differentiable (...T) throws -> R, ...R.TangentVector)
1056+
// rethrows -> (...T.TangentVector)
1057+
unsigned numGenericParams = 1 + arity;
1058+
BuiltinFunctionBuilder builder(Context, numGenericParams);
1059+
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
1060+
auto *addArithProto =
1061+
Context.getProtocol(KnownProtocolKind::AdditiveArithmetic);
1062+
// Create type parameters and add conformance constraints.
1063+
auto linearFnResultGen = makeGenericParam(arity);
1064+
builder.addConformanceRequirement(linearFnResultGen, diffableProto);
1065+
builder.addConformanceRequirement(linearFnResultGen, addArithProto);
1066+
SmallVector<decltype(linearFnResultGen), 2> linearFnParamGens;
1067+
for (auto i : range(arity)) {
1068+
auto T = makeGenericParam(i);
1069+
builder.addConformanceRequirement(T, diffableProto);
1070+
builder.addConformanceRequirement(T, addArithProto);
1071+
linearFnParamGens.push_back(T);
1072+
}
1073+
// Generator for the first argument, i.e. the `@differentiable(linear)`
1074+
// function.
1075+
BuiltinFunctionBuilder::LambdaGenerator firstArgGen {
1076+
// Generator for the function type at the argument position, i.e. the
1077+
// function being differentiated.
1078+
[=, &linearFnParamGens](BuiltinFunctionBuilder &builder) -> Type {
1079+
FunctionType::ExtInfo ext;
1080+
auto extInfo = FunctionType::ExtInfo()
1081+
.withDifferentiabilityKind(DifferentiabilityKind::Linear)
1082+
.withNoEscape().withThrows(throws);
1083+
SmallVector<FunctionType::Param, 2> params;
1084+
for (auto &paramGen : linearFnParamGens)
1085+
params.push_back(FunctionType::Param(paramGen.build(builder)));
1086+
auto innerFunction = FunctionType::get(params,
1087+
linearFnResultGen.build(builder));
1088+
return innerFunction->withExtInfo(extInfo);
1089+
}
1090+
};
1091+
builder.addParameter(firstArgGen);
1092+
builder.addParameter(linearFnResultGen);
1093+
if (throws)
1094+
builder.setRethrows();
1095+
if (arity == 1)
1096+
builder.setResult(linearFnParamGens.front());
1097+
else {
1098+
BuiltinFunctionBuilder::LambdaGenerator tupleResultGen {
1099+
[&](BuiltinFunctionBuilder &builder) -> Type {
1100+
SmallVector<TupleTypeElt, 2> tupleElts;
1101+
for (auto linearFnParamGen : linearFnParamGens)
1102+
tupleElts.push_back(linearFnParamGen.build(builder));
1103+
return TupleType::get(tupleElts, Context);
1104+
}
1105+
};
1106+
builder.setResult(tupleResultGen);
1107+
}
1108+
return builder.build(Id);
1109+
}
1110+
10511111
static ValueDecl *getDifferentiableFunctionConstructor(
10521112
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
10531113
assert(arity >= 1);
@@ -1992,15 +2052,23 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
19922052
return getAllocWithTailElemsOperation(Context, Id, NumTailTypes);
19932053
}
19942054
// SWIFT_ENABLE_TENSORFLOW
1995-
if (OperationName.startswith("autodiffApply_")) {
2055+
if (OperationName.startswith("applyDerivative_")) {
19962056
AutoDiffDerivativeFunctionKind kind;
19972057
unsigned arity;
1998-
bool rethrows;
1999-
if (!autodiff::getBuiltinAutoDiffApplyConfig(OperationName, kind, arity,
2000-
rethrows))
2058+
bool throws;
2059+
if (!autodiff::getBuiltinApplyDerivativeConfig(
2060+
OperationName, kind, arity, throws))
20012061
return nullptr;
20022062
return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity,
2003-
rethrows);
2063+
throws);
2064+
}
2065+
if (OperationName.startswith("applyTranspose_")) {
2066+
unsigned arity;
2067+
bool throws;
2068+
if (!autodiff::getBuiltinApplyTransposeConfig(
2069+
OperationName, arity, throws))
2070+
return nullptr;
2071+
return getAutoDiffApplyTransposeFunction(Context, Id, arity, throws);
20042072
}
20052073
if (OperationName.startswith("differentiableFunction_")) {
20062074
unsigned arity;
@@ -2288,7 +2356,8 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
22882356
return getUnsafeGuaranteedEnd(Context, Id);
22892357

22902358
// SWIFT_ENABLE_TENSORFLOW
2291-
case BuiltinValueKind::AutoDiffApply:
2359+
case BuiltinValueKind::ApplyDerivative:
2360+
case BuiltinValueKind::ApplyTranspose:
22922361
case BuiltinValueKind::DifferentiableFunction:
22932362
case BuiltinValueKind::LinearFunction:
22942363
llvm_unreachable("Handled above");

lib/SIL/SILModule.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,10 @@ const BuiltinInfo &SILModule::getBuiltinInfo(Identifier ID) {
310310
else if (OperationName.startswith("allocWithTailElems_"))
311311
Info.ID = BuiltinValueKind::AllocWithTailElems;
312312
// SWIFT_ENABLE_TENSORFLOW
313-
else if (OperationName.startswith("autodiffApply_"))
314-
Info.ID = BuiltinValueKind::AutoDiffApply;
313+
else if (OperationName.startswith("applyDerivative_"))
314+
Info.ID = BuiltinValueKind::ApplyDerivative;
315+
else if (OperationName.startswith("applyTranspose_"))
316+
Info.ID = BuiltinValueKind::ApplyTranspose;
315317
else if (OperationName.startswith("differentiableFunction_"))
316318
Info.ID = BuiltinValueKind::DifferentiableFunction;
317319
else if (OperationName.startswith("linearFunction_"))

lib/SIL/TypeLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ namespace {
962962

963963
SILValue rebuildAggregate(SILBuilder &B, SILLocation loc,
964964
ArrayRef<SILValue> values) const override {
965-
assert(values.size() == 3);
965+
assert(values.size() == 2);
966966
auto fnTy = getLoweredType().castTo<SILFunctionType>();
967967
auto paramIndices = fnTy->getDifferentiationParameterIndices();
968968
return B.createLinearFunction(loc, paramIndices, values[0], values[1]);

0 commit comments

Comments
 (0)