Skip to content

Commit 025cb9a

Browse files
author
marcrasi
authored
autodiff builtins (#30624)
Define type signatures and SILGen for the following builtins: ``` /// Applies the {jvp|vjp} of `f` to `arg1`, ..., `argN`. func applyDerivative_arityN_{jvp|vjp}(f, arg1, ..., argN) -> jvp/vjp return type /// Applies the transpose of `f` to `arg`. func applyTranspose_arityN(f, arg) -> transpose return type /// Makes a differentiable function from the given `original`, `jvp`, and /// `vjp` functions. func differentiableFunction_arityN(original, jvp, vjp) /// Makes a linear function from the given `original` and `transpose` functions. func linearFunction_arityN(original, transpose) ``` Add SILGen FileCheck tests for all builtins.
1 parent 972de2e commit 025cb9a

File tree

13 files changed

+838
-3
lines changed

13 files changed

+838
-3
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,33 @@ GenericSignature getConstrainedDerivativeGenericSignature(
441441
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
442442
bool isTranspose = false);
443443

444+
/// Retrieve config from the function name of a variant of
445+
/// `Builtin.applyDerivative`, e.g. `Builtin.applyDerivative_jvp_arity2`.
446+
/// Returns true if the function name is parsed successfully.
447+
bool getBuiltinApplyDerivativeConfig(
448+
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
449+
unsigned &arity, bool &rethrows);
450+
451+
/// Retrieve config from the function name of a variant of
452+
/// `Builtin.applyTranspose`, e.g. `Builtin.applyTranspose_arity2`.
453+
/// Returns true if the function name is parsed successfully.
454+
bool getBuiltinApplyTransposeConfig(
455+
StringRef operationName, unsigned &arity, bool &rethrows);
456+
457+
/// Retrieve config from the function name of a variant of
458+
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
459+
/// `Builtin.differentiableFunction_arity1_throws`.
460+
/// Returns true if the function name is parsed successfully.
461+
bool getBuiltinDifferentiableOrLinearFunctionConfig(
462+
StringRef operationName, unsigned &arity, bool &throws);
463+
464+
/// Retrieve config from the function name of a variant of
465+
/// `Builtin.differentiableFunction` or `Builtin.linearFunction`, e.g.
466+
/// `Builtin.differentiableFunction_arity1_throws`.
467+
/// Returns true if the function name is parsed successfully.
468+
bool getBuiltinDifferentiableOrLinearFunctionConfig(
469+
StringRef operationName, unsigned &arity, bool &throws);
470+
444471
} // end namespace autodiff
445472

446473
} // end namespace swift

include/swift/AST/Builtins.def

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,18 @@ BUILTIN_SIL_OPERATION(ConvertStrongToUnownedUnsafe, "convertStrongToUnownedUnsaf
469469
/// now.
470470
BUILTIN_SIL_OPERATION(ConvertUnownedUnsafeToGuaranteed, "convertUnownedUnsafeToGuaranteed", Special)
471471

472+
/// applyDerivative
473+
BUILTIN_SIL_OPERATION(ApplyDerivative, "applyDerivative", Special)
474+
475+
/// applyTranspose
476+
BUILTIN_SIL_OPERATION(ApplyTranspose, "applyTranspose", Special)
477+
478+
/// differentiableFunction
479+
BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special)
480+
481+
/// linearFunction
482+
BUILTIN_SIL_OPERATION(LinearFunction, "linearFunction", Special)
483+
472484
#undef BUILTIN_SIL_OPERATION
473485

474486
// BUILTIN_RUNTIME_CALL - A call into a runtime function.

include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ IDENTIFIER(withKeywordArguments)
133133
IDENTIFIER(wrapped)
134134
IDENTIFIER(wrappedValue)
135135
IDENTIFIER(wrapperValue)
136+
IDENTIFIER(differential)
137+
IDENTIFIER(pullback)
136138

137139
// Kinds of layout constraints
138140
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ PROTOCOL_(SwiftNewtypeWrapper)
7878
PROTOCOL(CodingKey)
7979
PROTOCOL(Encodable)
8080
PROTOCOL(Decodable)
81+
PROTOCOL(AdditiveArithmetic)
8182

8283
PROTOCOL_(ObjectiveCBridgeable)
8384
PROTOCOL_(DestructorSafeContainer)

include/swift/AST/Types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3353,6 +3353,8 @@ class AnyFunctionType : public TypeBase {
33533353
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
33543354
LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false);
33553355

3356+
AnyFunctionType *getWithoutDifferentiability() const;
3357+
33563358
/// True if the parameter declaration it is attached to is guaranteed
33573359
/// to not persist the closure for longer than the duration of the call.
33583360
bool isNoEscape() const {

lib/AST/AutoDiff.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,77 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
279279
nullptr);
280280
}
281281

282+
// Given the rest of a `Builtin.applyDerivative_{jvp|vjp}` or
283+
// `Builtin.applyTranspose` operation name, attempts to parse the arity and
284+
// throwing-ness from the operation name. Modifies the operation name argument
285+
// in place as substrings get dropped.
286+
static void parseAutoDiffBuiltinCommonConfig(
287+
StringRef &operationName, unsigned &arity, bool &throws) {
288+
// Parse '_arity'.
289+
constexpr char arityPrefix[] = "_arity";
290+
if (operationName.startswith(arityPrefix)) {
291+
operationName = operationName.drop_front(sizeof(arityPrefix) - 1);
292+
auto arityStr = operationName.take_while(llvm::isDigit);
293+
operationName = operationName.drop_front(arityStr.size());
294+
auto converted = llvm::to_integer(arityStr, arity);
295+
assert(converted); (void)converted;
296+
assert(arity > 0);
297+
} else {
298+
arity = 1;
299+
}
300+
// Parse '_throws'.
301+
constexpr char throwsPrefix[] = "_throws";
302+
if (operationName.startswith(throwsPrefix)) {
303+
operationName = operationName.drop_front(sizeof(throwsPrefix) - 1);
304+
throws = true;
305+
} else {
306+
throws = false;
307+
}
308+
}
309+
310+
bool autodiff::getBuiltinApplyDerivativeConfig(
311+
StringRef operationName, AutoDiffDerivativeFunctionKind &kind,
312+
unsigned &arity, bool &throws) {
313+
constexpr char prefix[] = "applyDerivative";
314+
if (!operationName.startswith(prefix))
315+
return false;
316+
operationName = operationName.drop_front(sizeof(prefix) - 1);
317+
// Parse 'jvp' or 'vjp'.
318+
constexpr char jvpPrefix[] = "_jvp";
319+
constexpr char vjpPrefix[] = "_vjp";
320+
if (operationName.startswith(jvpPrefix))
321+
kind = AutoDiffDerivativeFunctionKind::JVP;
322+
else if (operationName.startswith(vjpPrefix))
323+
kind = AutoDiffDerivativeFunctionKind::VJP;
324+
operationName = operationName.drop_front(sizeof(jvpPrefix) - 1);
325+
parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
326+
return operationName.empty();
327+
}
328+
329+
bool autodiff::getBuiltinApplyTransposeConfig(
330+
StringRef operationName, unsigned &arity, bool &throws) {
331+
constexpr char prefix[] = "applyTranspose";
332+
if (!operationName.startswith(prefix))
333+
return false;
334+
operationName = operationName.drop_front(sizeof(prefix) - 1);
335+
parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
336+
return operationName.empty();
337+
}
338+
339+
bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
340+
StringRef operationName, unsigned &arity, bool &throws) {
341+
constexpr char differentiablePrefix[] = "differentiableFunction";
342+
constexpr char linearPrefix[] = "linearFunction";
343+
if (operationName.startswith(differentiablePrefix))
344+
operationName = operationName.drop_front(sizeof(differentiablePrefix) - 1);
345+
else if (operationName.startswith(linearPrefix))
346+
operationName = operationName.drop_front(sizeof(linearPrefix) - 1);
347+
else
348+
return false;
349+
parseAutoDiffBuiltinCommonConfig(operationName, arity, throws);
350+
return operationName.empty();
351+
}
352+
282353
Type TangentSpace::getType() const {
283354
switch (kind) {
284355
case Kind::TangentVector:

0 commit comments

Comments
 (0)