Skip to content

[AutoDiff] Clean up. #27718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5602,7 +5602,6 @@ Automatic Differentiation

differentiable_function
```````````````````````

::

sil-instruction ::= 'differentiable_function'
Expand Down Expand Up @@ -5638,7 +5637,6 @@ clause. In canonical SIL, a ``with_derivative`` clause is mandatory.

linear_function
```````````````

::

sil-instruction ::= 'linear_function'
Expand Down Expand Up @@ -5670,7 +5668,6 @@ In canonical SIL, a ``with`` clause is mandatory.

differentiable_function_extract
```````````````````````````````

::

sil-instruction ::= 'differentiable_function_extract'
Expand All @@ -5692,7 +5689,6 @@ Extracts the original function or a derivative function from the given

linear_function_extract
```````````````````````

::

sil-instruction ::= 'linear_function_extract'
Expand Down
28 changes: 17 additions & 11 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class SILFunctionType;
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
enum class SILLinkage : uint8_t;

enum class DifferentiabilityKind: uint8_t {
enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
Expand All @@ -62,10 +62,10 @@ struct AutoDiffLinearMapKind {
/// The kind of a derivative function.
struct AutoDiffDerivativeFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1
} rawValue;

AutoDiffDerivativeFunctionKind() = default;
Expand All @@ -91,8 +91,8 @@ struct NormalDifferentiableFunctionTypeComponent {
: rawValue(rawValue) {}
NormalDifferentiableFunctionTypeComponent(
AutoDiffDerivativeFunctionKind kind);
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue) :
NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
operator innerty() const { return rawValue; }

Expand All @@ -108,8 +108,8 @@ struct LinearDifferentiableFunctionTypeComponent {
LinearDifferentiableFunctionTypeComponent() = default;
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
: rawValue(rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) :
LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
operator innerty() const { return rawValue; }
};
Expand All @@ -132,10 +132,10 @@ class ParsedAutoDiffParameter {

public:
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
: Loc(loc), Kind(kind), V(value) {}
: Loc(loc), Kind(kind), V(value) {}

ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, unsigned index)
: Loc(loc), Kind(kind), V(index) {}
: Loc(loc), Kind(kind), V(index) {}

static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
Identifier name) {
Expand Down Expand Up @@ -251,6 +251,12 @@ struct AutoDiffConfig {
IndexSubset *parameterIndices;
IndexSubset *resultIndices;
GenericSignature *derivativeGenericSignature;

/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
IndexSubset *resultIndices,
GenericSignature *derivativeGenericSignature)
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}
};

/// In conjunction with the original function declaration, identifies an
Expand Down
22 changes: 10 additions & 12 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -689,12 +689,6 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
// SIL differentiability witnesses
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
"expected '%0' in differentiability witness", (StringRef))
ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken,
"expected a space-separated list of indices, e.g. '0 1'", ())
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
"expected a parameter index to differentiate with respect to", ())
ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken,
"expected a result index to differentiate with respect to", ())

// SIL Coverage Map
ERROR(sil_coverage_func_not_found, none,
Expand Down Expand Up @@ -1596,16 +1590,20 @@ ERROR(sil_attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected an comma-separated list of parameter indices, e.g. (0, 1)", ())
ERROR(sil_attr_differentiable_expected_rsquare,PointsToFirstBadToken,
"expected ']' to end 'differentiable' attribute", ())
ERROR(sil_attr_differentiable_expected_parameter_index,PointsToFirstBadToken,
"expected the index of a parameter to differentiate w.r.t.", ())
ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken,
"expected the index of a result to differentiate from", ())

// SIL autodiff
ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
ERROR(sil_autodiff_expected_lsquare,PointsToFirstBadToken,
"expected '[' to start the %0", (StringRef))
ERROR(sil_autodiff_expected_rsquare,PointsToFirstBadToken,
"expected ']' to complete the %0", (StringRef))
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
ERROR(sil_autodiff_expected_index_list,PointsToFirstBadToken,
"expected a space-separated list of indices, e.g. '0 1'", ())
ERROR(sil_autodiff_expected_index_list_label,PointsToFirstBadToken,
"expected label '%0' in index list", (StringRef))
ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
"expected the index of a parameter to differentiate with respect to", ())
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
"expected the index of a result to differentiate from", ())
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
"expected '{' to start a derivative function list", ())
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,
Expand Down
21 changes: 9 additions & 12 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@ class SILDifferentiabilityWitness
SILLinkage linkage;
/// The original function.
SILFunction *originalFunction;
/// The parameter indices.
IndexSubset *parameterIndices;
/// The result indices.
IndexSubset *resultIndices;
/// The derivative generic signature (optional).
GenericSignature *derivativeGenericSignature;
/// The autodiff configuration: parameter indices, result indices, derivative
/// generic signature (optional).
AutoDiffConfig config;
/// The JVP (Jacobian-vector products) derivative function.
SILFunction *jvp;
/// The VJP (vector-Jacobian products) derivative function.
Expand All @@ -75,9 +72,8 @@ class SILDifferentiabilityWitness
SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute)
: module(module), linkage(linkage), originalFunction(originalFunction),
parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
serialized(isSerialized), attribute(attribute) {}
config(parameterIndices, resultIndices, derivativeGenSig), jvp(jvp),
vjp(vjp), serialized(isSerialized), attribute(attribute) {}

public:
static SILDifferentiabilityWitness *create(
Expand All @@ -90,14 +86,15 @@ class SILDifferentiabilityWitness
SILModule &getModule() const { return module; }
SILLinkage getLinkage() const { return linkage; }
SILFunction *getOriginalFunction() const { return originalFunction; }
const AutoDiffConfig &getConfig() const { return config; }
IndexSubset *getParameterIndices() const {
return parameterIndices;
return config.parameterIndices;
}
IndexSubset *getResultIndices() const {
return resultIndices;
return config.resultIndices;
}
GenericSignature *getDerivativeGenericSignature() const {
return derivativeGenericSignature;
return config.derivativeGenericSignature;
}
SILFunction *getJVP() const { return jvp; }
SILFunction *getVJP() const { return vjp; }
Expand Down
Loading