Skip to content

[AutoDiff] Add '@differentiable(reverse)'. #35811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,8 @@ Types
FUNCTION-KIND ::= 'E' // function type (noescape)
FUNCTION-KIND ::= 'F' // @differentiable function type
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)
FUNCTION-KIND ::= 'H' // @differentiable(_linear) function type
FUNCTION-KIND ::= 'I' // @differentiable(_linear) function type (escaping)

C-TYPE is mangled according to the Itanium ABI, and prefixed with the length.
Non-ASCII identifiers are preserved as-is; we do not use Punycode.
Expand Down Expand Up @@ -633,9 +633,10 @@ mangled in to disambiguate.

CALLEE-ESCAPE ::= 'e' // @escaping (inverse of SIL @noescape)

DIFFERENTIABILITY-KIND ::= DIFFERENTIABLE | LINEAR
DIFFERENTIABLE ::= 'd' // @differentiable
LINEAR ::= 'l' // @differentiable(linear)
DIFFERENTIABILITY-KIND ::= 'd' // @differentiable
DIFFERENTIABILITY-KIND ::= 'l' // @differentiable(_linear)
DIFFERENTIABILITY-KIND ::= 'f' // @differentiable(_forward)
DIFFERENTIABILITY-KIND ::= 'r' // @differentiable(reverse)

CALLEE-CONVENTION ::= 'y' // @callee_unowned
CALLEE-CONVENTION ::= 'g' // @callee_guaranteed
Expand Down
8 changes: 4 additions & 4 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7002,7 +7002,7 @@ linear_function
linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T

Bundles a function with its transpose function into a
``@differentiable(linear)`` function.
``@differentiable(_linear)`` function.

``[parameters ...]`` specifies parameter indices that the original function is
linear with respect to.
Expand Down Expand Up @@ -7051,11 +7051,11 @@ linear_function_extract

sil-linear-function-extractee ::= 'original' | 'transpose'

linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
linear_function_extract [original] %0 : $@differentiable(_linear) (T) -> T
linear_function_extract [transpose] %0 : $@differentiable(_linear) (T) -> T

Extracts the original function or a transpose function from the given
``@differentiable(linear)`` function. The extractee is one of the following:
``@differentiable(_linear)`` function. The extractee is one of the following:
``[original]`` or ``[transpose]``.


Expand Down
51 changes: 23 additions & 28 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -756,11 +756,13 @@ enum class FunctionMetadataConvention: uint8_t {
};

/// Differentiability kind for function type metadata.
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
/// Duplicates `DifferentiabilityKind` in AST/AutoDiff.h.
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
NonDifferentiable = 0b00,
Normal = 0b01,
Linear = 0b11
NonDifferentiable = 0b00000,
Forward = 0b00001,
Reverse = 0b00010,
Normal = 0b00011,
Linear = 0b10000,
};

/// Flags in a function type metadata record.
Expand All @@ -770,16 +772,16 @@ class TargetFunctionTypeFlags {
// one of the flag bits could be used to identify that the rest of
// the flags is going to be stored somewhere else in the metadata.
enum : int_type {
NumParametersMask = 0x0000FFFFU,
ConventionMask = 0x00FF0000U,
ConventionShift = 16U,
ThrowsMask = 0x01000000U,
ParamFlagsMask = 0x02000000U,
EscapingMask = 0x04000000U,
DifferentiableMask = 0x08000000U,
LinearMask = 0x10000000U,
AsyncMask = 0x20000000U,
ConcurrentMask = 0x40000000U,
NumParametersMask = 0x0000FFFFU,
ConventionMask = 0x00FF0000U,
ConventionShift = 16U,
ThrowsMask = 0x01000000U,
ParamFlagsMask = 0x02000000U,
EscapingMask = 0x04000000U,
DifferentiabilityMask = 0x98000000U,
DifferentiabilityShift = 27U,
AsyncMask = 0x20000000U,
ConcurrentMask = 0x40000000U,
};
int_type Data;

Expand Down Expand Up @@ -811,13 +813,9 @@ class TargetFunctionTypeFlags {
}

constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
FunctionMetadataDifferentiabilityKind differentiability) const {
return TargetFunctionTypeFlags<int_type>(
(Data & ~DifferentiableMask & ~LinearMask) |
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
? DifferentiableMask : 0) |
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
? LinearMask : 0));
FunctionMetadataDifferentiabilityKind differentiabilityKind) const {
return TargetFunctionTypeFlags((Data & ~DifferentiabilityMask)
| (int_type(differentiabilityKind) << DifferentiabilityShift));
}

constexpr TargetFunctionTypeFlags<int_type>
Expand Down Expand Up @@ -860,16 +858,13 @@ class TargetFunctionTypeFlags {
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }

bool isDifferentiable() const {
return getDifferentiabilityKind() >=
FunctionMetadataDifferentiabilityKind::Normal;
return getDifferentiabilityKind() !=
FunctionMetadataDifferentiabilityKind::NonDifferentiable;
}

FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
if (bool(Data & DifferentiableMask))
return FunctionMetadataDifferentiabilityKind::Normal;
if (bool(Data & LinearMask))
return FunctionMetadataDifferentiabilityKind::Linear;
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
return FunctionMetadataDifferentiabilityKind(
(Data & DifferentiabilityMask) >> DifferentiabilityShift);
}

int_type getIntValue() const {
Expand Down
51 changes: 33 additions & 18 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class TypeAttributes {

// Indicates whether the type's '@differentiable' attribute has a 'linear'
// argument.
bool linear = false;
DifferentiabilityKind differentiabilityKind =
DifferentiabilityKind::NonDifferentiable;

// For an opened existential type, the known ID.
Optional<UUID> OpenedID;
Expand All @@ -102,14 +103,6 @@ class TypeAttributes {

bool isValid() const { return AtLoc.isValid(); }

bool isLinear() const {
assert(
!linear ||
(linear && has(TAK_differentiable)) &&
"Linear shouldn't have been true if there's no `@differentiable`");
return linear;
}

void clearAttribute(TypeAttrKind A) {
AttrLocs[A] = SourceLoc();
}
Expand Down Expand Up @@ -1790,8 +1783,9 @@ class OriginallyDefinedInAttr: public DeclAttribute {
/// Attribute that marks a function as differentiable.
///
/// Examples:
/// @differentiable(where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y))
/// @differentiable(reverse)
/// @differentiable(reverse, wrt: (self, x, y))
/// @differentiable(reverse, wrt: (self, x, y) where T : FloatingPoint)
class DifferentiableAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiableAttr,
Expand All @@ -1803,8 +1797,8 @@ class DifferentiableAttr final
/// May not be a valid declaration for `@differentiable` attributes.
/// Resolved during parsing and deserialization.
Decl *OriginalDeclaration = nullptr;
/// Whether this function is linear (optional).
bool Linear;
/// The differentiability kind.
DifferentiabilityKind DifferentiabilityKind;
/// The number of parsed differentiability parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiability parameter indices, resolved by the type checker.
Expand All @@ -1830,25 +1824,28 @@ class DifferentiableAttr final
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;

explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
SourceRange baseRange,
enum DifferentiabilityKind diffKind,
ArrayRef<ParsedAutoDiffParameter> parameters,
TrailingWhereClause *clause);

explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
SourceRange baseRange,
enum DifferentiabilityKind diffKind,
IndexSubset *parameterIndices,
GenericSignature derivativeGenericSignature);

public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
enum DifferentiabilityKind diffKind,
ArrayRef<ParsedAutoDiffParameter> params,
TrailingWhereClause *clause);

static DifferentiableAttr *create(AbstractFunctionDecl *original,
bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
SourceRange baseRange,
enum DifferentiabilityKind diffKind,
IndexSubset *parameterIndices,
GenericSignature derivativeGenSig);

Expand Down Expand Up @@ -1879,7 +1876,25 @@ class DifferentiableAttr final
return NumParsedParameters;
}

bool isLinear() const { return Linear; }
enum DifferentiabilityKind getDifferentiabilityKind() const {
return DifferentiabilityKind;
}

bool isNormalDifferentiability() const {
return DifferentiabilityKind == DifferentiabilityKind::Normal;
}

bool isLinearDifferentiability() const {
return DifferentiabilityKind == DifferentiabilityKind::Linear;
}

bool isForwardDifferentiability() const {
return DifferentiabilityKind == DifferentiabilityKind::Forward;
}

bool isReverseDifferentiability() const {
return DifferentiabilityKind == DifferentiabilityKind::Reverse;
}

TrailingWhereClause *getWhereClause() const { return WhereClause; }

Expand Down
20 changes: 16 additions & 4 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ class VarDecl;
/// A function type differentiability kind.
enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
// '@differentiable(_forward)', rejected by parser.
Forward = 1,
// '@differentiable(reverse)', supported.
Reverse = 2,
// '@differentiable', unsupported.
Normal = 3,
// '@differentiable(_linear)', unsupported.
Linear = 4,
};

/// The kind of an linear map.
Expand Down Expand Up @@ -74,9 +80,15 @@ struct AutoDiffDerivativeFunctionKind {
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
explicit AutoDiffDerivativeFunctionKind(StringRef string);
operator innerty() const { return rawValue; }
AutoDiffLinearMapKind getLinearMapKind() {
AutoDiffLinearMapKind getLinearMapKind() const {
return (AutoDiffLinearMapKind::innerty)rawValue;
}
DifferentiabilityKind getMinimalDifferentiabilityKind() const {
switch (rawValue) {
case JVP: return DifferentiabilityKind::Forward;
case VJP: return DifferentiabilityKind::Reverse;
}
}
};

/// A component of a SIL `@differentiable` function-typed value.
Expand All @@ -98,7 +110,7 @@ struct NormalDifferentiableFunctionTypeComponent {
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

/// A component of a SIL `@differentiable(linear)` function-typed value.
/// A component of a SIL `@differentiable(_linear)` function-typed value.
struct LinearDifferentiableFunctionTypeComponent {
enum innerty : unsigned {
Original = 0,
Expand Down
9 changes: 7 additions & 2 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1610,14 +1610,19 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
"expected a member name as second parameter in '_implements' attribute", ())

// differentiable
WARNING(attr_differentiable_expected_reverse,PointsToFirstBadToken,
"'@differentiable' has been renamed to '@differentiable(reverse)' and "
"will be removed in the next release", ())
ERROR(attr_differentiable_kind_not_supported,PointsToFirstBadToken,
"unsupported differentiability kind '%0'; only 'reverse' is supported", (StringRef))
ERROR(attr_differentiable_unknown_kind,PointsToFirstBadToken,
"unknown differentiability kind '%0'; only 'reverse' is supported", (StringRef))
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to", ())
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
"use 'wrt:' to specify parameters to differentiate with respect to", ())
ERROR(attr_differentiable_expected_label,none,
"expected 'wrt:' or 'where' in '@differentiable' attribute", ())
ERROR(attr_differentiable_unexpected_argument,none,
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))

// differentiation `wrt` parameters clause
ERROR(expected_colon_after_label,PointsToFirstBadToken,
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ ERROR(autodiff_differentiation_module_not_imported,none,
"Automatic differentiation requires the '_Differentiation' module to be "
"imported", ())
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
"conversion to '@differentiable(linear)' function type is not yet "
"conversion to '@differentiable(_linear)' function type is not yet "
"supported", ())
ERROR(autodiff_function_not_differentiable_error,none,
"function is not differentiable", ())
Expand Down
14 changes: 4 additions & 10 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -1278,14 +1278,11 @@ ERROR(c_function_pointer_from_method,none,
ERROR(c_function_pointer_from_generic_function,none,
"a C function pointer cannot be formed from a reference to a generic "
"function", ())
ERROR(unsupported_linear_to_differentiable_conversion,none,
"conversion from '@differentiable(linear)' to '@differentiable' is not "
"yet supported", ())
ERROR(invalid_autoclosure_forwarding,none,
"add () to forward @autoclosure parameter", ())
ERROR(invalid_differentiable_function_conversion_expr,none,
"a '@differentiable%select{|(linear)}0' function can only be formed from "
"a reference to a 'func' or 'init' or a literal closure", (bool))
"a '@differentiable' function can only be formed from "
"a reference to a 'func' or 'init' or a literal closure", ())
NOTE(invalid_differentiable_function_conversion_parameter,none,
"did you mean to take a '%0' closure?", (StringRef))
ERROR(invalid_autoclosure_pointer_conversion,none,
Expand Down Expand Up @@ -3107,9 +3104,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
(Identifier, Identifier))

// @differentiable
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
"attribute for transpose registration instead", ())
ERROR(differentiable_attr_overload_not_found,none,
"%0 does not have expected type %1", (DeclNameRef, Type))
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
Expand Down Expand Up @@ -4526,14 +4520,14 @@ ERROR(attr_only_on_parameters_of_differentiable,none,
ERROR(differentiable_function_type_invalid_parameter,none,
"parameter type '%0' does not conform to 'Differentiable'"
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(linear)}1'"
"function type is '@differentiable%select{|(_linear)}1'"
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
(StringRef, /*isLinear*/ bool,
/*hasValidDifferentiabilityParameter*/ bool))
ERROR(differentiable_function_type_invalid_result,none,
"result type '%0' does not conform to 'Differentiable'"
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(linear)}1'",
"function type is '@differentiable%select{|(_linear)}1'",
(StringRef, bool))
ERROR(differentiable_function_type_no_differentiability_parameters,
none,
Expand Down
14 changes: 7 additions & 7 deletions include/swift/AST/ExtInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class ASTExtInfoBuilder {
// and NumMaskBits must be updated, and they must match.
//
// |representation|noEscape|concurrent|async|throws|differentiability|
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 9 |
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 |
//
enum : unsigned {
RepresentationMask = 0xF << 0,
Expand All @@ -302,8 +302,8 @@ class ASTExtInfoBuilder {
AsyncMask = 1 << 6,
ThrowsMask = 1 << 7,
DifferentiabilityMaskOffset = 8,
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
NumMaskBits = 10
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
NumMaskBits = 11
};

unsigned bits; // Naturally sized for speed.
Expand Down Expand Up @@ -615,8 +615,8 @@ class SILExtInfoBuilder {
// If bits are added or removed, then TypeBase::SILFunctionTypeBits
// and NumMaskBits must be updated, and they must match.

// |representation|pseudogeneric| noescape | concurrent | async | differentiability|
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 9 |
// |representation|pseudogeneric| noescape | concurrent | async |differentiability|
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 |
//
enum : unsigned {
RepresentationMask = 0xF << 0,
Expand All @@ -625,8 +625,8 @@ class SILExtInfoBuilder {
ConcurrentMask = 1 << 6,
AsyncMask = 1 << 7,
DifferentiabilityMaskOffset = 8,
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
NumMaskBits = 10
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
NumMaskBits = 11
};

unsigned bits; // Naturally sized for speed.
Expand Down
Loading