Skip to content

Commit af8942d

Browse files
committed
[AutoDiff] Rename '@differentiable' to '@differentiable(reverse)'.
Compiler: - Add `Forward` and `Reverse` to `DifferentiabilityKind`. - Expand `DifferentiabilityMask` in `ExtInfo` to 3 bits so that it now holds all 4 cases of `DifferentiabilityKind`. - Parse `@differentiable(reverse)` and `@differentiable(_forward)` declaration attributes and type attributes. - Emit a warning for `@differentiable` without `reverse`. - Emit an error for `@differentiable(_forward)`. - Rename `@differentiable(linear)` to `@differentiable(_linear)`. - Make `@differentiable(reverse)` type lowering go through today's `@differentiable` code path. We will specialize it to reverse-mode in a follow-up patch. ABI: - Add `Forward` and `Reverse` to `FunctionMetadataDifferentiabilityKind`. - Extend `TargetFunctionTypeFlags` by 1 bit to store the highest bit of differentiability kind (linear). Note that there is a 2-bit gap in `DifferentiabilityMask` which is reserved for `AsyncMask` and `ConcurrentMask`; `AsyncMask` is ABI-stable so we cannot change that. _Differentiation module: - Replace all occurrences of `@differentiable` with `@differentiable(reverse)`. - Delete `_transpose(of:)`. Resolves rdar://69980056.
1 parent ce587f0 commit af8942d

File tree

149 files changed

+1673
-1535
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+1673
-1535
lines changed

docs/ABI/Mangling.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,8 @@ Types
550550
FUNCTION-KIND ::= 'E' // function type (noescape)
551551
FUNCTION-KIND ::= 'F' // @differentiable function type
552552
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
553-
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
554-
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)
553+
FUNCTION-KIND ::= 'H' // @differentiable(_linear) function type
554+
FUNCTION-KIND ::= 'I' // @differentiable(_linear) function type (escaping)
555555

556556
C-TYPE is mangled according to the Itanium ABI, and prefixed with the length.
557557
Non-ASCII identifiers are preserved as-is; we do not use Punycode.
@@ -633,9 +633,10 @@ mangled in to disambiguate.
633633

634634
CALLEE-ESCAPE ::= 'e' // @escaping (inverse of SIL @noescape)
635635

636-
DIFFERENTIABILITY-KIND ::= DIFFERENTIABLE | LINEAR
637-
DIFFERENTIABLE ::= 'd' // @differentiable
638-
LINEAR ::= 'l' // @differentiable(linear)
636+
DIFFERENTIABILITY-KIND ::= 'd' // @differentiable
637+
DIFFERENTIABILITY-KIND ::= 'l' // @differentiable(_linear)
638+
DIFFERENTIABILITY-KIND ::= 'f' // @differentiable(_forward)
639+
DIFFERENTIABILITY-KIND ::= 'r' // @differentiable(reverse)
639640

640641
CALLEE-CONVENTION ::= 'y' // @callee_unowned
641642
CALLEE-CONVENTION ::= 'g' // @callee_guaranteed

docs/SIL.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7002,7 +7002,7 @@ linear_function
70027002
linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T
70037003

70047004
Bundles a function with its transpose function into a
7005-
``@differentiable(linear)`` function.
7005+
``@differentiable(_linear)`` function.
70067006

70077007
``[parameters ...]`` specifies parameter indices that the original function is
70087008
linear with respect to.
@@ -7051,11 +7051,11 @@ linear_function_extract
70517051

70527052
sil-linear-function-extractee ::= 'original' | 'transpose'
70537053

7054-
linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
7055-
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
7054+
linear_function_extract [original] %0 : $@differentiable(_linear) (T) -> T
7055+
linear_function_extract [transpose] %0 : $@differentiable(_linear) (T) -> T
70567056

70577057
Extracts the original function or a transpose function from the given
7058-
``@differentiable(linear)`` function. The extractee is one of the following:
7058+
``@differentiable(_linear)`` function. The extractee is one of the following:
70597059
``[original]`` or ``[transpose]``.
70607060

70617061

include/swift/ABI/MetadataValues.h

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -756,11 +756,13 @@ enum class FunctionMetadataConvention: uint8_t {
756756
};
757757

758758
/// Differentiability kind for function type metadata.
759-
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
759+
/// Duplicates `DifferentiabilityKind` in AST/AutoDiff.h.
760760
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
761-
NonDifferentiable = 0b00,
762-
Normal = 0b01,
763-
Linear = 0b11
761+
NonDifferentiable = 0b00000,
762+
Forward = 0b00001,
763+
Reverse = 0b00010,
764+
Normal = 0b00011,
765+
Linear = 0b10000,
764766
};
765767

766768
/// Flags in a function type metadata record.
@@ -770,16 +772,16 @@ class TargetFunctionTypeFlags {
770772
// one of the flag bits could be used to identify that the rest of
771773
// the flags is going to be stored somewhere else in the metadata.
772774
enum : int_type {
773-
NumParametersMask = 0x0000FFFFU,
774-
ConventionMask = 0x00FF0000U,
775-
ConventionShift = 16U,
776-
ThrowsMask = 0x01000000U,
777-
ParamFlagsMask = 0x02000000U,
778-
EscapingMask = 0x04000000U,
779-
DifferentiableMask = 0x08000000U,
780-
LinearMask = 0x10000000U,
781-
AsyncMask = 0x20000000U,
782-
ConcurrentMask = 0x40000000U,
775+
NumParametersMask = 0x0000FFFFU,
776+
ConventionMask = 0x00FF0000U,
777+
ConventionShift = 16U,
778+
ThrowsMask = 0x01000000U,
779+
ParamFlagsMask = 0x02000000U,
780+
EscapingMask = 0x04000000U,
781+
DifferentiabilityMask = 0x98000000U,
782+
DifferentiabilityShift = 27U,
783+
AsyncMask = 0x20000000U,
784+
ConcurrentMask = 0x40000000U,
783785
};
784786
int_type Data;
785787

@@ -811,13 +813,9 @@ class TargetFunctionTypeFlags {
811813
}
812814

813815
constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
814-
FunctionMetadataDifferentiabilityKind differentiability) const {
815-
return TargetFunctionTypeFlags<int_type>(
816-
(Data & ~DifferentiableMask & ~LinearMask) |
817-
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
818-
? DifferentiableMask : 0) |
819-
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
820-
? LinearMask : 0));
816+
FunctionMetadataDifferentiabilityKind differentiabilityKind) const {
817+
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
818+
| (int_type(differentiabilityKind) << DifferentiabilityShift));
821819
}
822820

823821
constexpr TargetFunctionTypeFlags<int_type>
@@ -860,16 +858,13 @@ class TargetFunctionTypeFlags {
860858
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }
861859

862860
bool isDifferentiable() const {
863-
return getDifferentiabilityKind() >=
864-
FunctionMetadataDifferentiabilityKind::Normal;
861+
return getDifferentiabilityKind() !=
862+
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
865863
}
866864

867865
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
868-
if (bool(Data & DifferentiableMask))
869-
return FunctionMetadataDifferentiabilityKind::Normal;
870-
if (bool(Data & LinearMask))
871-
return FunctionMetadataDifferentiabilityKind::Linear;
872-
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
866+
return FunctionMetadataDifferentiabilityKind(
867+
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
873868
}
874869

875870
int_type getIntValue() const {

include/swift/AST/Attr.h

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ class TypeAttributes {
8585

8686
// Indicates whether the type's '@differentiable' attribute has a 'linear'
8787
// argument.
88-
bool linear = false;
88+
DifferentiabilityKind differentiabilityKind =
89+
DifferentiabilityKind::NonDifferentiable;
8990

9091
// For an opened existential type, the known ID.
9192
Optional<UUID> OpenedID;
@@ -102,14 +103,6 @@ class TypeAttributes {
102103

103104
bool isValid() const { return AtLoc.isValid(); }
104105

105-
bool isLinear() const {
106-
assert(
107-
!linear ||
108-
(linear && has(TAK_differentiable)) &&
109-
"Linear shouldn't have been true if there's no `@differentiable`");
110-
return linear;
111-
}
112-
113106
void clearAttribute(TypeAttrKind A) {
114107
AttrLocs[A] = SourceLoc();
115108
}
@@ -1790,8 +1783,9 @@ class OriginallyDefinedInAttr: public DeclAttribute {
17901783
/// Attribute that marks a function as differentiable.
17911784
///
17921785
/// Examples:
1793-
/// @differentiable(where T : FloatingPoint)
1794-
/// @differentiable(wrt: (self, x, y))
1786+
/// @differentiable(reverse)
1787+
/// @differentiable(reverse, wrt: (self, x, y))
1788+
/// @differentiable(reverse, wrt: (self, x, y) where T : FloatingPoint)
17951789
class DifferentiableAttr final
17961790
: public DeclAttribute,
17971791
private llvm::TrailingObjects<DifferentiableAttr,
@@ -1803,8 +1797,8 @@ class DifferentiableAttr final
18031797
/// May not be a valid declaration for `@differentiable` attributes.
18041798
/// Resolved during parsing and deserialization.
18051799
Decl *OriginalDeclaration = nullptr;
1806-
/// Whether this function is linear (optional).
1807-
bool Linear;
1800+
/// The differentiability kind.
1801+
DifferentiabilityKind DifferentiabilityKind;
18081802
/// The number of parsed differentiability parameters specified in 'wrt:'.
18091803
unsigned NumParsedParameters = 0;
18101804
/// The differentiability parameter indices, resolved by the type checker.
@@ -1830,25 +1824,28 @@ class DifferentiableAttr final
18301824
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;
18311825

18321826
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
1833-
SourceRange baseRange, bool linear,
1827+
SourceRange baseRange,
1828+
enum DifferentiabilityKind diffKind,
18341829
ArrayRef<ParsedAutoDiffParameter> parameters,
18351830
TrailingWhereClause *clause);
18361831

18371832
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
1838-
SourceRange baseRange, bool linear,
1833+
SourceRange baseRange,
1834+
enum DifferentiabilityKind diffKind,
18391835
IndexSubset *parameterIndices,
18401836
GenericSignature derivativeGenericSignature);
18411837

18421838
public:
18431839
static DifferentiableAttr *create(ASTContext &context, bool implicit,
18441840
SourceLoc atLoc, SourceRange baseRange,
1845-
bool linear,
1841+
enum DifferentiabilityKind diffKind,
18461842
ArrayRef<ParsedAutoDiffParameter> params,
18471843
TrailingWhereClause *clause);
18481844

18491845
static DifferentiableAttr *create(AbstractFunctionDecl *original,
18501846
bool implicit, SourceLoc atLoc,
1851-
SourceRange baseRange, bool linear,
1847+
SourceRange baseRange,
1848+
enum DifferentiabilityKind diffKind,
18521849
IndexSubset *parameterIndices,
18531850
GenericSignature derivativeGenSig);
18541851

@@ -1879,7 +1876,25 @@ class DifferentiableAttr final
18791876
return NumParsedParameters;
18801877
}
18811878

1882-
bool isLinear() const { return Linear; }
1879+
enum DifferentiabilityKind getDifferentiabilityKind() const {
1880+
return DifferentiabilityKind;
1881+
}
1882+
1883+
bool isNormalDifferentiability() const {
1884+
return DifferentiabilityKind == DifferentiabilityKind::Normal;
1885+
}
1886+
1887+
bool isLinearDifferentiability() const {
1888+
return DifferentiabilityKind == DifferentiabilityKind::Linear;
1889+
}
1890+
1891+
bool isForwardDifferentiability() const {
1892+
return DifferentiabilityKind == DifferentiabilityKind::Forward;
1893+
}
1894+
1895+
bool isReverseDifferentiability() const {
1896+
return DifferentiabilityKind == DifferentiabilityKind::Reverse;
1897+
}
18831898

18841899
TrailingWhereClause *getWhereClause() const { return WhereClause; }
18851900

include/swift/AST/AutoDiff.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,14 @@ class VarDecl;
4141
/// A function type differentiability kind.
4242
enum class DifferentiabilityKind : uint8_t {
4343
NonDifferentiable = 0,
44-
Normal = 1,
45-
Linear = 2
44+
// '@differentiable(_forward)', rejected by parser.
45+
Forward = 1,
46+
// '@differentiable(reverse)', supported.
47+
Reverse = 2,
48+
// '@differentiable', unsupported.
49+
Normal = 3,
50+
// '@differentiable(_linear)', unsupported.
51+
Linear = 4,
4652
};
4753

4854
/// The kind of an linear map.
@@ -74,9 +80,15 @@ struct AutoDiffDerivativeFunctionKind {
7480
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
7581
explicit AutoDiffDerivativeFunctionKind(StringRef string);
7682
operator innerty() const { return rawValue; }
77-
AutoDiffLinearMapKind getLinearMapKind() {
83+
AutoDiffLinearMapKind getLinearMapKind() const {
7884
return (AutoDiffLinearMapKind::innerty)rawValue;
7985
}
86+
DifferentiabilityKind getMinimalDifferentiabilityKind() const {
87+
switch (rawValue) {
88+
case JVP: return DifferentiabilityKind::Forward;
89+
case VJP: return DifferentiabilityKind::Reverse;
90+
}
91+
}
8092
};
8193

8294
/// A component of a SIL `@differentiable` function-typed value.
@@ -98,7 +110,7 @@ struct NormalDifferentiableFunctionTypeComponent {
98110
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
99111
};
100112

101-
/// A component of a SIL `@differentiable(linear)` function-typed value.
113+
/// A component of a SIL `@differentiable(_linear)` function-typed value.
102114
struct LinearDifferentiableFunctionTypeComponent {
103115
enum innerty : unsigned {
104116
Original = 0,

include/swift/AST/DiagnosticsParse.def

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,14 +1610,19 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
16101610
"expected a member name as second parameter in '_implements' attribute", ())
16111611

16121612
// differentiable
1613+
WARNING(attr_differentiable_expected_reverse,PointsToFirstBadToken,
1614+
"'@differentiable' has been renamed to '@differentiable(reverse)' and "
1615+
"will be removed in the next release", ())
1616+
ERROR(attr_differentiable_kind_not_supported,PointsToFirstBadToken,
1617+
"unsupported differentiability kind '%0'; only 'reverse' is supported", (StringRef))
1618+
ERROR(attr_differentiable_unknown_kind,PointsToFirstBadToken,
1619+
"unknown differentiability kind '%0'; only 'reverse' is supported", (StringRef))
16131620
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
16141621
"expected a list of parameters to differentiate with respect to", ())
16151622
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
16161623
"use 'wrt:' to specify parameters to differentiate with respect to", ())
16171624
ERROR(attr_differentiable_expected_label,none,
16181625
"expected 'wrt:' or 'where' in '@differentiable' attribute", ())
1619-
ERROR(attr_differentiable_unexpected_argument,none,
1620-
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
16211626

16221627
// differentiation `wrt` parameters clause
16231628
ERROR(expected_colon_after_label,PointsToFirstBadToken,

include/swift/AST/DiagnosticsSIL.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ ERROR(autodiff_differentiation_module_not_imported,none,
435435
"Automatic differentiation requires the '_Differentiation' module to be "
436436
"imported", ())
437437
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
438-
"conversion to '@differentiable(linear)' function type is not yet "
438+
"conversion to '@differentiable(_linear)' function type is not yet "
439439
"supported", ())
440440
ERROR(autodiff_function_not_differentiable_error,none,
441441
"function is not differentiable", ())

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,14 +1278,11 @@ ERROR(c_function_pointer_from_method,none,
12781278
ERROR(c_function_pointer_from_generic_function,none,
12791279
"a C function pointer cannot be formed from a reference to a generic "
12801280
"function", ())
1281-
ERROR(unsupported_linear_to_differentiable_conversion,none,
1282-
"conversion from '@differentiable(linear)' to '@differentiable' is not "
1283-
"yet supported", ())
12841281
ERROR(invalid_autoclosure_forwarding,none,
12851282
"add () to forward @autoclosure parameter", ())
12861283
ERROR(invalid_differentiable_function_conversion_expr,none,
1287-
"a '@differentiable%select{|(linear)}0' function can only be formed from "
1288-
"a reference to a 'func' or 'init' or a literal closure", (bool))
1284+
"a '@differentiable' function can only be formed from "
1285+
"a reference to a 'func' or 'init' or a literal closure", ())
12891286
NOTE(invalid_differentiable_function_conversion_parameter,none,
12901287
"did you mean to take a '%0' closure?", (StringRef))
12911288
ERROR(invalid_autoclosure_pointer_conversion,none,
@@ -3107,9 +3104,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
31073104
(Identifier, Identifier))
31083105

31093106
// @differentiable
3110-
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
3111-
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
3112-
"attribute for transpose registration instead", ())
31133107
ERROR(differentiable_attr_overload_not_found,none,
31143108
"%0 does not have expected type %1", (DeclNameRef, Type))
31153109
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
@@ -4526,14 +4520,14 @@ ERROR(attr_only_on_parameters_of_differentiable,none,
45264520
ERROR(differentiable_function_type_invalid_parameter,none,
45274521
"parameter type '%0' does not conform to 'Differentiable'"
45284522
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4529-
"function type is '@differentiable%select{|(linear)}1'"
4523+
"function type is '@differentiable%select{|(_linear)}1'"
45304524
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
45314525
(StringRef, /*isLinear*/ bool,
45324526
/*hasValidDifferentiabilityParameter*/ bool))
45334527
ERROR(differentiable_function_type_invalid_result,none,
45344528
"result type '%0' does not conform to 'Differentiable'"
45354529
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4536-
"function type is '@differentiable%select{|(linear)}1'",
4530+
"function type is '@differentiable%select{|(_linear)}1'",
45374531
(StringRef, bool))
45384532
ERROR(differentiable_function_type_no_differentiability_parameters,
45394533
none,

include/swift/AST/ExtInfo.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class ASTExtInfoBuilder {
293293
// and NumMaskBits must be updated, and they must match.
294294
//
295295
// |representation|noEscape|concurrent|async|throws|differentiability|
296-
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 9 |
296+
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 |
297297
//
298298
enum : unsigned {
299299
RepresentationMask = 0xF << 0,
@@ -302,8 +302,8 @@ class ASTExtInfoBuilder {
302302
AsyncMask = 1 << 6,
303303
ThrowsMask = 1 << 7,
304304
DifferentiabilityMaskOffset = 8,
305-
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
306-
NumMaskBits = 10
305+
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
306+
NumMaskBits = 11
307307
};
308308

309309
unsigned bits; // Naturally sized for speed.
@@ -615,8 +615,8 @@ class SILExtInfoBuilder {
615615
// If bits are added or removed, then TypeBase::SILFunctionTypeBits
616616
// and NumMaskBits must be updated, and they must match.
617617

618-
// |representation|pseudogeneric| noescape | concurrent | async | differentiability|
619-
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 9 |
618+
// |representation|pseudogeneric| noescape | concurrent | async |differentiability|
619+
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 |
620620
//
621621
enum : unsigned {
622622
RepresentationMask = 0xF << 0,
@@ -625,8 +625,8 @@ class SILExtInfoBuilder {
625625
ConcurrentMask = 1 << 6,
626626
AsyncMask = 1 << 7,
627627
DifferentiabilityMaskOffset = 8,
628-
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
629-
NumMaskBits = 10
628+
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
629+
NumMaskBits = 11
630630
};
631631

632632
unsigned bits; // Naturally sized for speed.

0 commit comments

Comments
 (0)