Skip to content

Commit 3f7d4c1

Browse files
committed
Pull in the upstream @differentiable attribute changes (#28198)
1 parent 2d702b3 commit 3f7d4c1

File tree

4 files changed

+42
-30
lines changed

4 files changed

+42
-30
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,23 +141,23 @@ class ParsedAutoDiffParameter {
141141
enum class Kind { Named, Ordered, Self };
142142

143143
private:
144-
SourceLoc Loc;
145-
Kind Kind;
144+
SourceLoc loc;
145+
Kind kind;
146146
union Value {
147-
struct { Identifier Name; } Named;
148-
struct { unsigned Index; } Ordered;
149-
struct {} Self;
147+
struct { Identifier name; } Named;
148+
struct { unsigned index; } Ordered;
149+
struct {} self;
150150
Value(Identifier name) : Named({name}) {}
151151
Value(unsigned index) : Ordered({index}) {}
152152
Value() {}
153-
} V;
153+
} value;
154154

155155
public:
156-
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
157-
: Loc(loc), Kind(kind), V(value) {}
158-
159-
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, unsigned index)
160-
: Loc(loc), Kind(kind), V(index) {}
156+
ParsedAutoDiffParameter(SourceLoc loc, Kind kind, Value value)
157+
: loc(loc), kind(kind), value(value) {}
158+
159+
ParsedAutoDiffParameter(SourceLoc loc, Kind kind, unsigned index)
160+
: loc(loc), kind(kind), value(index) {}
161161

162162
static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
163163
Identifier name) {
@@ -174,20 +174,20 @@ class ParsedAutoDiffParameter {
174174
}
175175

176176
Identifier getName() const {
177-
assert(Kind == Kind::Named);
178-
return V.Named.Name;
177+
assert(kind == Kind::Named);
178+
return value.Named.name;
179179
}
180180

181181
unsigned getIndex() const {
182-
return V.Ordered.Index;
182+
return value.Ordered.index;
183183
}
184184

185-
enum Kind getKind() const {
186-
return Kind;
185+
Kind getKind() const {
186+
return kind;
187187
}
188188

189189
SourceLoc getLoc() const {
190-
return Loc;
190+
return loc;
191191
}
192192

193193
bool isEqual(const ParsedAutoDiffParameter &other) const {

include/swift/Parse/Parser.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,20 @@ class Parser {
711711
/// Check whether the current token starts with '>'.
712712
bool startsWithGreater(Token Tok) { return startsWithSymbol(Tok, '>'); }
713713

714+
/// Returns true if token is an identifier with the given value.
715+
bool isIdentifier(Token Tok, StringRef value) {
716+
return Tok.is(tok::identifier) && Tok.getText() == value;
717+
}
718+
719+
/// Returns true if token is the identifier "wrt".
720+
bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); }
721+
722+
/// Returns true if token is the identifier "jvp".
723+
bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); }
724+
725+
/// Returns true if token is the identifier "vjp".
726+
bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); }
727+
714728
/// Consume the starting '<' of the current token, which may either
715729
/// be a complete '<' token or some kind of operator token starting with '<',
716730
/// e.g., '<>'.

lib/Parse/ParseDecl.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,9 +1026,8 @@ bool Parser::parseDifferentiableAttributeArguments(
10261026
return true;
10271027
}
10281028
// Check that token after comma is 'wrt:' or a function specifier label.
1029-
if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||
1030-
Tok.getText() == "jvp" ||
1031-
Tok.getText() == "vjp")) {
1029+
if (!(isWRTIdentifier(Tok) || isJVPIdentifier(Tok) ||
1030+
isVJPIdentifier(Tok))) {
10321031
diagnose(Tok, diag::attr_differentiable_expected_label);
10331032
return true;
10341033
}
@@ -1047,7 +1046,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10471046
linear = true;
10481047
consumeToken(tok::identifier);
10491048
// If no trailing comma or 'where' clause, terminate parsing arguments.
1050-
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
1049+
if (Tok.isNot(tok::comma, tok::kw_where))
10511050
return false;
10521051
if (consumeIfTrailingComma())
10531052
return errorAndSkipToEnd();
@@ -1061,11 +1060,11 @@ bool Parser::parseDifferentiableAttributeArguments(
10611060
.fixItReplace(withRespectToRange, "wrt:");
10621061
return errorAndSkipToEnd();
10631062
}
1064-
if (Tok.is(tok::identifier) && Tok.getText() == "wrt") {
1063+
if (isWRTIdentifier(Tok)) {
10651064
if (parseDifferentiationParametersClause(params, AttrName))
10661065
return true;
10671066
// If no trailing comma or 'where' clause, terminate parsing arguments.
1068-
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
1067+
if (Tok.isNot(tok::comma, tok::kw_where))
10691068
return false;
10701069
if (consumeIfTrailingComma())
10711070
return errorAndSkipToEnd();
@@ -1090,7 +1089,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10901089
funcDiag, /*allowOperators=*/true,
10911090
/*allowZeroArgCompoundNames=*/true);
10921091
// If no trailing comma or 'where' clause, terminate parsing arguments.
1093-
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
1092+
if (Tok.isNot(tok::comma, tok::kw_where))
10941093
terminateParsingArgs = true;
10951094
return !result.Name;
10961095
};
@@ -1099,7 +1098,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10991098
bool terminateParsingArgs = false;
11001099

11011100
// Parse 'jvp: <func_name>' (optional).
1102-
if (Tok.is(tok::identifier) && Tok.getText() == "jvp") {
1101+
if (isJVPIdentifier(Tok)) {
11031102
SyntaxParsingContext JvpContext(
11041103
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
11051104
jvpSpec = DeclNameWithLoc();
@@ -1112,7 +1111,7 @@ bool Parser::parseDifferentiableAttributeArguments(
11121111
}
11131112

11141113
// Parse 'vjp: <func_name>' (optional).
1115-
if (Tok.is(tok::identifier) && Tok.getText() == "vjp") {
1114+
if (isVJPIdentifier(Tok)) {
11161115
SyntaxParsingContext VjpContext(
11171116
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
11181117
vjpSpec = DeclNameWithLoc();
@@ -1127,8 +1126,7 @@ bool Parser::parseDifferentiableAttributeArguments(
11271126
}
11281127

11291128
// If parser has not advanced and token is not 'where' or ')', emit error.
1130-
if (Tok.getLoc() == startingLoc &&
1131-
Tok.isNot(tok::kw_where) && Tok.isNot(tok::r_paren)) {
1129+
if (Tok.getLoc() == startingLoc && Tok.isNot(tok::kw_where, tok::r_paren)) {
11321130
diagnose(Tok, diag::attr_differentiable_expected_label);
11331131
return errorAndSkipToEnd();
11341132
}

lib/ParseSIL/ParseSIL.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7023,7 +7023,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
70237023
SourceLoc lBraceLoc;
70247024
P.consumeIf(tok::l_brace, lBraceLoc);
70257025
// Parse JVP (optional).
7026-
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "jvp") {
7026+
if (P.isJVPIdentifier(P.Tok)) {
70277027
P.consumeToken(tok::identifier);
70287028
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":"))
70297029
return true;
@@ -7032,7 +7032,7 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
70327032
return true;
70337033
}
70347034
// Parse VJP (optional).
7035-
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "vjp") {
7035+
if (P.isVJPIdentifier(P.Tok)) {
70367036
P.consumeToken(tok::identifier);
70377037
if (P.parseToken(tok::colon, diag::sil_diff_witness_expected_token, ":"))
70387038
return true;

0 commit comments

Comments
 (0)