Skip to content

Commit 92b0f22

Browse files
authored
[AutoDiff] Clean up. (#27718)
- Store `AutoDiffConfig` in `SILDifferentiabilityWitness` instead of storing the individual components. This makes it cheaper to get an `AutoDiffConfig`. - Unify parsing logic and diagnostics. - Minor style changes.
1 parent b575448 commit 92b0f22

File tree

8 files changed

+96
-160
lines changed

8 files changed

+96
-160
lines changed

docs/SIL.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5602,7 +5602,6 @@ Automatic Differentiation
56025602

56035603
differentiable_function
56045604
```````````````````````
5605-
56065605
::
56075606

56085607
sil-instruction ::= 'differentiable_function'
@@ -5638,7 +5637,6 @@ clause. In canonical SIL, a ``with_derivative`` clause is mandatory.
56385637

56395638
linear_function
56405639
```````````````
5641-
56425640
::
56435641

56445642
sil-instruction ::= 'linear_function'
@@ -5670,7 +5668,6 @@ In canonical SIL, a ``with`` clause is mandatory.
56705668

56715669
differentiable_function_extract
56725670
```````````````````````````````
5673-
56745671
::
56755672

56765673
sil-instruction ::= 'differentiable_function_extract'
@@ -5692,7 +5689,6 @@ Extracts the original function or a derivative function from the given
56925689

56935690
linear_function_extract
56945691
```````````````````````
5695-
56965692
::
56975693

56985694
sil-instruction ::= 'linear_function_extract'

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class SILFunctionType;
3939
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
4040
enum class SILLinkage : uint8_t;
4141

42-
enum class DifferentiabilityKind: uint8_t {
42+
enum class DifferentiabilityKind : uint8_t {
4343
NonDifferentiable = 0,
4444
Normal = 1,
4545
Linear = 2
@@ -62,10 +62,10 @@ struct AutoDiffLinearMapKind {
6262
/// The kind of a derivative function.
6363
struct AutoDiffDerivativeFunctionKind {
6464
enum innerty : uint8_t {
65-
// The Jacobian-vector products function.
66-
JVP = 0,
67-
// The vector-Jacobian products function.
68-
VJP = 1
65+
// The Jacobian-vector products function.
66+
JVP = 0,
67+
// The vector-Jacobian products function.
68+
VJP = 1
6969
} rawValue;
7070

7171
AutoDiffDerivativeFunctionKind() = default;
@@ -91,8 +91,8 @@ struct NormalDifferentiableFunctionTypeComponent {
9191
: rawValue(rawValue) {}
9292
NormalDifferentiableFunctionTypeComponent(
9393
AutoDiffDerivativeFunctionKind kind);
94-
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue) :
95-
NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
94+
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
95+
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
9696
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
9797
operator innerty() const { return rawValue; }
9898

@@ -108,8 +108,8 @@ struct LinearDifferentiableFunctionTypeComponent {
108108
LinearDifferentiableFunctionTypeComponent() = default;
109109
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
110110
: rawValue(rawValue) {}
111-
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) :
112-
LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
111+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
112+
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
113113
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
114114
operator innerty() const { return rawValue; }
115115
};
@@ -132,10 +132,10 @@ class ParsedAutoDiffParameter {
132132

133133
public:
134134
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
135-
: Loc(loc), Kind(kind), V(value) {}
135+
: Loc(loc), Kind(kind), V(value) {}
136136

137137
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, unsigned index)
138-
: Loc(loc), Kind(kind), V(index) {}
138+
: Loc(loc), Kind(kind), V(index) {}
139139

140140
static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
141141
Identifier name) {
@@ -251,6 +251,12 @@ struct AutoDiffConfig {
251251
IndexSubset *parameterIndices;
252252
IndexSubset *resultIndices;
253253
GenericSignature *derivativeGenericSignature;
254+
255+
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
256+
IndexSubset *resultIndices,
257+
GenericSignature *derivativeGenericSignature)
258+
: parameterIndices(parameterIndices), resultIndices(resultIndices),
259+
derivativeGenericSignature(derivativeGenericSignature) {}
254260
};
255261

256262
/// In conjunction with the original function declaration, identifies an

include/swift/AST/DiagnosticsParse.def

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -689,12 +689,6 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
689689
// SIL differentiability witnesses
690690
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
691691
"expected '%0' in differentiability witness", (StringRef))
692-
ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken,
693-
"expected a space-separated list of indices, e.g. '0 1'", ())
694-
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
695-
"expected a parameter index to differentiate with respect to", ())
696-
ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken,
697-
"expected a result index to differentiate with respect to", ())
698692

699693
// SIL Coverage Map
700694
ERROR(sil_coverage_func_not_found, none,
@@ -1596,16 +1590,20 @@ ERROR(sil_attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
15961590
"expected an comma-separated list of parameter indices, e.g. (0, 1)", ())
15971591
ERROR(sil_attr_differentiable_expected_rsquare,PointsToFirstBadToken,
15981592
"expected ']' to end 'differentiable' attribute", ())
1599-
ERROR(sil_attr_differentiable_expected_parameter_index,PointsToFirstBadToken,
1600-
"expected the index of a parameter to differentiate w.r.t.", ())
1601-
ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken,
1602-
"expected the index of a result to differentiate from", ())
16031593

16041594
// SIL autodiff
1605-
ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
1595+
ERROR(sil_autodiff_expected_lsquare,PointsToFirstBadToken,
1596+
"expected '[' to start the %0", (StringRef))
1597+
ERROR(sil_autodiff_expected_rsquare,PointsToFirstBadToken,
16061598
"expected ']' to complete the %0", (StringRef))
1607-
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
1599+
ERROR(sil_autodiff_expected_index_list,PointsToFirstBadToken,
1600+
"expected a space-separated list of indices, e.g. '0 1'", ())
1601+
ERROR(sil_autodiff_expected_index_list_label,PointsToFirstBadToken,
1602+
"expected label '%0' in index list", (StringRef))
1603+
ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
16081604
"expected the index of a parameter to differentiate with respect to", ())
1605+
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
1606+
"expected the index of a result to differentiate from", ())
16091607
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
16101608
"expected '{' to start a derivative function list", ())
16111609
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,9 @@ class SILDifferentiabilityWitness
4848
SILLinkage linkage;
4949
/// The original function.
5050
SILFunction *originalFunction;
51-
/// The parameter indices.
52-
IndexSubset *parameterIndices;
53-
/// The result indices.
54-
IndexSubset *resultIndices;
55-
/// The derivative generic signature (optional).
56-
GenericSignature *derivativeGenericSignature;
51+
/// The autodiff configuration: parameter indices, result indices, derivative
52+
/// generic signature (optional).
53+
AutoDiffConfig config;
5754
/// The JVP (Jacobian-vector products) derivative function.
5855
SILFunction *jvp;
5956
/// The VJP (vector-Jacobian products) derivative function.
@@ -75,9 +72,8 @@ class SILDifferentiabilityWitness
7572
SILFunction *jvp, SILFunction *vjp,
7673
bool isSerialized, DeclAttribute *attribute)
7774
: module(module), linkage(linkage), originalFunction(originalFunction),
78-
parameterIndices(parameterIndices), resultIndices(resultIndices),
79-
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
80-
serialized(isSerialized), attribute(attribute) {}
75+
config(parameterIndices, resultIndices, derivativeGenSig), jvp(jvp),
76+
vjp(vjp), serialized(isSerialized), attribute(attribute) {}
8177

8278
public:
8379
static SILDifferentiabilityWitness *create(
@@ -90,14 +86,15 @@ class SILDifferentiabilityWitness
9086
SILModule &getModule() const { return module; }
9187
SILLinkage getLinkage() const { return linkage; }
9288
SILFunction *getOriginalFunction() const { return originalFunction; }
89+
const AutoDiffConfig &getConfig() const { return config; }
9390
IndexSubset *getParameterIndices() const {
94-
return parameterIndices;
91+
return config.parameterIndices;
9592
}
9693
IndexSubset *getResultIndices() const {
97-
return resultIndices;
94+
return config.resultIndices;
9895
}
9996
GenericSignature *getDerivativeGenericSignature() const {
100-
return derivativeGenericSignature;
97+
return config.derivativeGenericSignature;
10198
}
10299
SILFunction *getJVP() const { return jvp; }
103100
SILFunction *getVJP() const { return vjp; }

0 commit comments

Comments
 (0)