@@ -1026,9 +1026,8 @@ bool Parser::parseDifferentiableAttributeArguments(
1026
1026
return true ;
1027
1027
}
1028
1028
// Check that token after comma is 'wrt:' or a function specifier label.
1029
- if (!Tok.is (tok::identifier) || !(Tok.getText () == " wrt" ||
1030
- Tok.getText () == " jvp" ||
1031
- Tok.getText () == " vjp" )) {
1029
+ if (!(isWRTIdentifier (Tok) || isJVPIdentifier (Tok) ||
1030
+ isVJPIdentifier (Tok))) {
1032
1031
diagnose (Tok, diag::attr_differentiable_expected_label);
1033
1032
return true ;
1034
1033
}
@@ -1061,11 +1060,11 @@ bool Parser::parseDifferentiableAttributeArguments(
1061
1060
.fixItReplace (withRespectToRange, " wrt:" );
1062
1061
return errorAndSkipToEnd ();
1063
1062
}
1064
- if (Tok. is (tok::identifier) && Tok. getText () == " wrt " ) {
1063
+ if (isWRTIdentifier ( Tok) ) {
1065
1064
if (parseDifferentiationParametersClause (params, AttrName))
1066
1065
return true ;
1067
1066
// If no trailing comma or 'where' clause, terminate parsing arguments.
1068
- if (Tok.isNot (tok::comma) && Tok. isNot ( tok::kw_where))
1067
+ if (Tok.isNot (tok::comma, tok::kw_where))
1069
1068
return false ;
1070
1069
if (consumeIfTrailingComma ())
1071
1070
return errorAndSkipToEnd ();
@@ -1099,7 +1098,7 @@ bool Parser::parseDifferentiableAttributeArguments(
1099
1098
bool terminateParsingArgs = false ;
1100
1099
1101
1100
// Parse 'jvp: <func_name>' (optional).
1102
- if (Tok. is (tok::identifier) && Tok. getText () == " jvp " ) {
1101
+ if (isJVPIdentifier ( Tok) ) {
1103
1102
SyntaxParsingContext JvpContext (
1104
1103
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
1105
1104
jvpSpec = DeclNameWithLoc ();
@@ -1112,7 +1111,7 @@ bool Parser::parseDifferentiableAttributeArguments(
1112
1111
}
1113
1112
1114
1113
// Parse 'vjp: <func_name>' (optional).
1115
- if (Tok. is (tok::identifier) && Tok. getText () == " vjp " ) {
1114
+ if (isVJPIdentifier ( Tok) ) {
1116
1115
SyntaxParsingContext VjpContext (
1117
1116
SyntaxContext, SyntaxKind::DifferentiableAttributeFuncSpecifier);
1118
1117
vjpSpec = DeclNameWithLoc ();
0 commit comments