Skip to content

Commit 83492d8

Browse files
committed
Adds is(WRT|JVP|VJP)Identifier helper function.
1 parent 298911b commit 83492d8

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

include/swift/Parse/Parser.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,20 @@ class Parser {
711711
/// Check whether the current token starts with '>'.
712712
bool startsWithGreater(Token Tok) { return startsWithSymbol(Tok, '>'); }
713713

714+
/// Returns true if token is an identifier with the given value.
715+
bool isIdentifier(Token Tok, StringRef value) {
716+
return Tok.is(tok::identifier) && Tok.getText() == value;
717+
}
718+
719+
/// Returns true if token is the identifier "wrt".
720+
bool isWRTIdentifier(Token tok) { return isIdentifier(Tok, "wrt"); }
721+
722+
/// Returns true if token is the identifier "jvp".
723+
bool isJVPIdentifier(Token Tok) { return isIdentifier(Tok, "jvp"); }
724+
725+
/// Returns true if token is the identifier "vjp".
726+
bool isVJPIdentifier(Token Tok) { return isIdentifier(Tok, "vjp"); }
727+
714728
/// Consume the starting '<' of the current token, which may either
715729
/// be a complete '<' token or some kind of operator token starting with '<',
716730
/// e.g., '<>'.

lib/Parse/ParseDecl.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,9 +1026,8 @@ bool Parser::parseDifferentiableAttributeArguments(
10261026
return true;
10271027
}
10281028
// Check that token after comma is 'wrt:' or a function specifier label.
1029-
if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||
1030-
Tok.getText() == "jvp" ||
1031-
Tok.getText() == "vjp")) {
1029+
if (!(isWRTIdentifier(Tok) || isJVPIdentifier(Tok) ||
1030+
isVJPIdentifier(Tok))) {
10321031
diagnose(Tok, diag::attr_differentiable_expected_label);
10331032
return true;
10341033
}
@@ -1061,11 +1060,11 @@ bool Parser::parseDifferentiableAttributeArguments(
10611060
.fixItReplace(withRespectToRange, "wrt:");
10621061
return errorAndSkipToEnd();
10631062
}
1064-
if (Tok.is(tok::identifier) && Tok.getText() == "wrt") {
1063+
if (isWRTIdentifier(Tok)) {
10651064
if (parseDifferentiationParametersClause(params, AttrName))
10661065
return true;
10671066
// If no trailing comma or 'where' clause, terminate parsing arguments.
1068-
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
1067+
if (Tok.isNot(tok::comma, tok::kw_where))
10691068
return false;
10701069
if (consumeIfTrailingComma())
10711070
return errorAndSkipToEnd();
@@ -1099,7 +1098,7 @@ bool Parser::parseDifferentiableAttributeArguments(
10991098
bool terminateParsingArgs = false;
11001099

11011100
// Parse 'jvp: <func_name>' (optional).
1102-
if (Tok.is(tok::identifier) && Tok.getText() == "jvp") {
1101+
if (isJVPIdentifier(Tok)) {
11031102
SyntaxParsingContext JvpContext(
11041103
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
11051104
jvpSpec = DeclNameWithLoc();
@@ -1112,7 +1111,7 @@ bool Parser::parseDifferentiableAttributeArguments(
11121111
}
11131112

11141113
// Parse 'vjp: <func_name>' (optional).
1115-
if (Tok.is(tok::identifier) && Tok.getText() == "vjp") {
1114+
if (isVJPIdentifier(Tok)) {
11161115
SyntaxParsingContext VjpContext(
11171116
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
11181117
vjpSpec = DeclNameWithLoc();

0 commit comments

Comments
 (0)