Skip to content

Commit 7dde30e

Browse files
committed
[AutoDiff] Clean up.
- Store `AutoDiffConfig` in `SILDifferentiabilityWitness` instead of storing the individual components. This makes it cheaper to get an `AutoDiffConfig`. - Unify parsing logic and diagnostics. - Minor style changes.
1 parent 76729c4 commit 7dde30e

File tree

6 files changed

+77
-88
lines changed

6 files changed

+77
-88
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class SILFunctionType;
3939
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
4040
enum class SILLinkage : uint8_t;
4141

42-
enum class DifferentiabilityKind: uint8_t {
42+
enum class DifferentiabilityKind : uint8_t {
4343
NonDifferentiable = 0,
4444
Normal = 1,
4545
Linear = 2
@@ -62,10 +62,10 @@ struct AutoDiffLinearMapKind {
6262
/// The kind of a derivative function.
6363
struct AutoDiffDerivativeFunctionKind {
6464
enum innerty : uint8_t {
65-
// The Jacobian-vector products function.
66-
JVP = 0,
67-
// The vector-Jacobian products function.
68-
VJP = 1
65+
// The Jacobian-vector products function.
66+
JVP = 0,
67+
// The vector-Jacobian products function.
68+
VJP = 1
6969
} rawValue;
7070

7171
AutoDiffDerivativeFunctionKind() = default;
@@ -91,8 +91,8 @@ struct NormalDifferentiableFunctionTypeComponent {
9191
: rawValue(rawValue) {}
9292
NormalDifferentiableFunctionTypeComponent(
9393
AutoDiffDerivativeFunctionKind kind);
94-
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue) :
95-
NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
94+
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
95+
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
9696
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
9797
operator innerty() const { return rawValue; }
9898

@@ -108,8 +108,8 @@ struct LinearDifferentiableFunctionTypeComponent {
108108
LinearDifferentiableFunctionTypeComponent() = default;
109109
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
110110
: rawValue(rawValue) {}
111-
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) :
112-
LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
111+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
112+
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
113113
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
114114
operator innerty() const { return rawValue; }
115115
};
@@ -132,10 +132,10 @@ class ParsedAutoDiffParameter {
132132

133133
public:
134134
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
135-
: Loc(loc), Kind(kind), V(value) {}
135+
: Loc(loc), Kind(kind), V(value) {}
136136

137137
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, unsigned index)
138-
: Loc(loc), Kind(kind), V(index) {}
138+
: Loc(loc), Kind(kind), V(index) {}
139139

140140
static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
141141
Identifier name) {
@@ -251,6 +251,12 @@ struct AutoDiffConfig {
251251
IndexSubset *parameterIndices;
252252
IndexSubset *resultIndices;
253253
GenericSignature *derivativeGenericSignature;
254+
255+
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
256+
IndexSubset *resultIndices,
257+
GenericSignature *derivativeGenericSignature)
258+
: parameterIndices(parameterIndices), resultIndices(resultIndices),
259+
derivativeGenericSignature(derivativeGenericSignature) {}
254260
};
255261

256262
/// In conjunction with the original function declaration, identifies an

include/swift/AST/DiagnosticsParse.def

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -689,12 +689,6 @@ ERROR(sil_witness_protocol_conformance_not_found,none,
689689
// SIL differentiability witnesses
690690
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
691691
"expected '%0' in differentiability witness", (StringRef))
692-
ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken,
693-
"expected a space-separated list of indices, e.g. '0 1'", ())
694-
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
695-
"expected a parameter index to differentiate with respect to", ())
696-
ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken,
697-
"expected a result index to differentiate with respect to", ())
698692

699693
// SIL Coverage Map
700694
ERROR(sil_coverage_func_not_found, none,
@@ -1596,16 +1590,20 @@ ERROR(sil_attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
15961590
"expected an comma-separated list of parameter indices, e.g. (0, 1)", ())
15971591
ERROR(sil_attr_differentiable_expected_rsquare,PointsToFirstBadToken,
15981592
"expected ']' to end 'differentiable' attribute", ())
1599-
ERROR(sil_attr_differentiable_expected_parameter_index,PointsToFirstBadToken,
1600-
"expected the index of a parameter to differentiate w.r.t.", ())
1601-
ERROR(sil_attr_differentiable_expected_source_index,PointsToFirstBadToken,
1602-
"expected the index of a result to differentiate from", ())
16031593

16041594
// SIL autodiff
1605-
ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
1595+
ERROR(sil_autodiff_expected_lsquare,PointsToFirstBadToken,
1596+
"expected '[' to start the %0", (StringRef))
1597+
ERROR(sil_autodiff_expected_rsquare,PointsToFirstBadToken,
16061598
"expected ']' to complete the %0", (StringRef))
1607-
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
1599+
ERROR(sil_autodiff_expected_index_list,PointsToFirstBadToken,
1600+
"expected a space-separated list of indices, e.g. '0 1'", ())
1601+
ERROR(sil_autodiff_expected_index_list_label,PointsToFirstBadToken,
1602+
"expected label '%0' in index list", (StringRef))
1603+
ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
16081604
"expected the index of a parameter to differentiate with respect to", ())
1605+
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
1606+
"expected the index of a result to differentiate from", ())
16091607
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
16101608
"expected '{' to start a derivative function list", ())
16111609
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,9 @@ class SILDifferentiabilityWitness
4848
SILLinkage linkage;
4949
/// The original function.
5050
SILFunction *originalFunction;
51-
/// The parameter indices.
52-
IndexSubset *parameterIndices;
53-
/// The result indices.
54-
IndexSubset *resultIndices;
55-
/// The derivative generic signature (optional).
56-
GenericSignature *derivativeGenericSignature;
51+
/// The autodiff configuration: parameter indices, result indices, derivative
52+
/// generic signature (optional).
53+
AutoDiffConfig config;
5754
/// The JVP (Jacobian-vector products) derivative function.
5855
SILFunction *jvp;
5956
/// The VJP (vector-Jacobian products) derivative function.
@@ -75,9 +72,8 @@ class SILDifferentiabilityWitness
7572
SILFunction *jvp, SILFunction *vjp,
7673
bool isSerialized, DeclAttribute *attribute)
7774
: module(module), linkage(linkage), originalFunction(originalFunction),
78-
parameterIndices(parameterIndices), resultIndices(resultIndices),
79-
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
80-
serialized(isSerialized), attribute(attribute) {}
75+
config(parameterIndices, resultIndices, derivativeGenSig), jvp(jvp),
76+
vjp(vjp), serialized(isSerialized), attribute(attribute) {}
8177

8278
public:
8379
static SILDifferentiabilityWitness *create(
@@ -90,14 +86,15 @@ class SILDifferentiabilityWitness
9086
SILModule &getModule() const { return module; }
9187
SILLinkage getLinkage() const { return linkage; }
9288
SILFunction *getOriginalFunction() const { return originalFunction; }
89+
const AutoDiffConfig &getConfig() const { return config; }
9390
IndexSubset *getParameterIndices() const {
94-
return parameterIndices;
91+
return config.parameterIndices;
9592
}
9693
IndexSubset *getResultIndices() const {
97-
return resultIndices;
94+
return config.resultIndices;
9895
}
9996
GenericSignature *getDerivativeGenericSignature() const {
100-
return derivativeGenericSignature;
97+
return config.derivativeGenericSignature;
10198
}
10299
SILFunction *getJVP() const { return jvp; }
103100
SILFunction *getVJP() const { return vjp; }

lib/ParseSIL/ParseSIL.cpp

Lines changed: 38 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ static bool parseDifferentiableAttr(
992992
if (P.parseSpecificIdentifier(
993993
"source", diag::sil_attr_differentiable_expected_keyword, "source") ||
994994
P.parseUnsignedInteger(SourceIndex, LastLoc,
995-
diag::sil_attr_differentiable_expected_source_index))
995+
diag::sil_autodiff_expected_result_index))
996996
return true;
997997
// Parse 'wrt'.
998998
if (P.parseSpecificIdentifier(
@@ -1005,7 +1005,7 @@ static bool parseDifferentiableAttr(
10051005
unsigned Index;
10061006
// TODO: Reject non-ascending parameter index lists.
10071007
if (P.parseUnsignedInteger(Index, LastLoc,
1008-
diag::sil_attr_differentiable_expected_parameter_index))
1008+
diag::sil_autodiff_expected_parameter_index))
10091009
return true;
10101010
ParamIndices.push_back(Index);
10111011
return false;
@@ -2939,12 +2939,11 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29392939
while (P.Tok.is(tok::integer_literal)) {
29402940
unsigned index;
29412941
if (P.parseUnsignedInteger(index, lastLoc,
2942-
diag::sil_inst_autodiff_expected_parameter_index))
2942+
diag::sil_autodiff_expected_parameter_index))
29432943
return true;
29442944
parameterIndices.push_back(index);
29452945
}
2946-
if (P.parseToken(tok::r_square,
2947-
diag::sil_inst_autodiff_attr_expected_rsquare,
2946+
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
29482947
"parameter index list"))
29492948
return true;
29502949
}
@@ -3002,12 +3001,11 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30023001
while (P.Tok.is(tok::integer_literal)) {
30033002
unsigned index;
30043003
if (P.parseUnsignedInteger(index, lastLoc,
3005-
diag::sil_inst_autodiff_expected_parameter_index))
3004+
diag::sil_autodiff_expected_parameter_index))
30063005
return true;
30073006
parameterIndices.push_back(index);
30083007
}
3009-
if (P.parseToken(tok::r_square,
3010-
diag::sil_inst_autodiff_attr_expected_rsquare,
3008+
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
30113009
"parameter index list"))
30123010
return true;
30133011
}
@@ -3050,8 +3048,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30503048
diag::sil_inst_autodiff_expected_differentiable_extractee_kind) ||
30513049
parseSILIdentifierSwitch(extractee, extracteeNames,
30523050
diag::sil_inst_autodiff_expected_differentiable_extractee_kind) ||
3053-
P.parseToken(tok::r_square,
3054-
diag::sil_inst_autodiff_attr_expected_rsquare,
3051+
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
30553052
"extractee kind"))
30563053
return true;
30573054
if (parseTypedValueRef(functionOperand, B) ||
@@ -3073,8 +3070,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30733070
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
30743071
parseSILIdentifierSwitch(extractee, extracteeNames,
30753072
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
3076-
P.parseToken(tok::r_square,
3077-
diag::sil_inst_autodiff_attr_expected_rsquare,
3073+
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
30783074
"extractee kind"))
30793075
return true;
30803076
if (parseTypedValueRef(functionOperand, B) ||
@@ -6941,47 +6937,35 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69416937
};
69426938

69436939
SourceLoc lastLoc = P.getEndOfPreviousLoc();
6944-
// Parse an index subset, prefaced with the given label.
6945-
auto parseIndexSubset =
6946-
[&](StringRef label, IndexSubset *& indexSubset) -> bool {
6947-
if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_token, "["))
6948-
return true;
6949-
if (P.parseSpecificIdentifier(
6950-
label, diag::sil_diff_witness_expected_token, label))
6951-
return true;
6952-
// Parse parameter index list.
6953-
SmallVector<unsigned, 8> paramIndices;
6954-
// Function that parses an index into `paramIndices`. Returns true on error.
6955-
auto parseParam = [&]() -> bool {
6940+
// Parse an index set, prefaced with the given label.
6941+
auto parseIndexSet = [&](StringRef label, SmallVectorImpl<unsigned> &indices,
6942+
const Diagnostic &parseIndexDiag) -> bool {
6943+
// Parse `[<label> <integer_literal>...]`.
6944+
if (P.parseToken(tok::l_square, diag::sil_autodiff_expected_lsquare,
6945+
"index list") ||
6946+
P.parseSpecificIdentifier(
6947+
label, diag::sil_autodiff_expected_index_list_label, label))
6948+
return true;
6949+
while (P.Tok.is(tok::integer_literal)) {
69566950
unsigned index;
6957-
// TODO: Reject non-ascending index lists.
6958-
if (P.parseUnsignedInteger(index, lastLoc,
6959-
diag::sil_diff_witness_expected_index_list))
6951+
if (P.parseUnsignedInteger(index, lastLoc, parseIndexDiag))
69606952
return true;
6961-
paramIndices.push_back(index);
6962-
return false;
6963-
};
6964-
// Parse first.
6965-
if (parseParam())
6966-
return true;
6967-
// Parse rest.
6968-
while (P.Tok.isNot(tok::r_square))
6969-
if (parseParam())
6970-
return true;
6971-
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]"))
6953+
indices.push_back(index);
6954+
}
6955+
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
6956+
"index list"))
69726957
return true;
6973-
auto maxIndexRef =
6974-
std::max_element(paramIndices.begin(), paramIndices.end());
6975-
indexSubset = IndexSubset::get(
6976-
P.Context, maxIndexRef ? *maxIndexRef + 1 : 0, paramIndices);
69776958
return false;
69786959
};
69796960
// Parse parameter and result indices.
6980-
IndexSubset *parameterIndices = nullptr;
6981-
IndexSubset *resultIndices = nullptr;
6982-
if (parseIndexSubset("parameters", parameterIndices))
6961+
SmallVector<unsigned, 8> parameterIndices;
6962+
SmallVector<unsigned, 8> resultIndices;
6963+
// Parse parameter and result indices.
6964+
if (parseIndexSet("parameters", parameterIndices,
6965+
diag::sil_autodiff_expected_parameter_index))
69836966
return true;
6984-
if (parseIndexSubset("results", resultIndices))
6967+
if (parseIndexSet("results", resultIndices,
6968+
diag::sil_autodiff_expected_result_index))
69856969
return true;
69866970

69876971
// Parse a trailing 'where' clause (optional).
@@ -6995,7 +6979,8 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69956979
P.parseGenericWhereClause(whereLoc, derivativeRequirementReprs,
69966980
firstTypeInComplete,
69976981
/*AllowLayoutConstraints*/ false);
6998-
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_token, "]"))
6982+
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
6983+
"where clause"))
69996984
return true;
70006985
}
70016986

@@ -7053,8 +7038,13 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
70537038
return true;
70547039
}
70557040

7041+
auto origFnType = originalFn->getLoweredFunctionType();
7042+
auto *parameterIndexSet = IndexSubset::get(
7043+
P.Context, origFnType->getNumParameters(), parameterIndices);
7044+
auto *resultIndexSet = IndexSubset::get(
7045+
P.Context, origFnType->getNumResults(), resultIndices);
70567046
SILDifferentiabilityWitness::create(
7057-
M, *linkage, originalFn, parameterIndices, resultIndices,
7047+
M, *linkage, originalFn, parameterIndexSet, resultIndexSet,
70587048
derivativeGenSig, jvp, vjp, isSerialized);
70597049
return false;
70607050
}

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,5 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
3434
}
3535

3636
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
37-
AutoDiffConfig config{parameterIndices, resultIndices,
38-
derivativeGenericSignature};
39-
return std::make_pair(originalFunction->getName(), config);
37+
return std::make_pair(originalFunction->getName(), getConfig());
4038
}

lib/SILGen/SILGen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
773773
if (auto *vjpDecl = diffAttr->getVJPFunction())
774774
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
775775
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
776-
AutoDiffConfig config{diffAttr->getParameterIndices(), resultIndices,
777-
diffAttr->getDerivativeGenericSignature()};
776+
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
777+
diffAttr->getDerivativeGenericSignature());
778778
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
779779
}
780780
}

0 commit comments

Comments
 (0)