@@ -957,8 +957,9 @@ bool Parser::parseDifferentiableAttributeArguments(
957
957
diagnose (Tok, diag::unexpected_separator, " ," );
958
958
return true ;
959
959
}
960
- // Check that token after comma is a function specifier label.
961
- if (!Tok.is (tok::identifier) || !(Tok.getText () == " jvp" ||
960
+ // Check that token after comma is 'wrt:' or a function specifier label.
961
+ if (!Tok.is (tok::identifier) || !(Tok.getText () == " wrt" ||
962
+ Tok.getText () == " jvp" ||
962
963
Tok.getText () == " vjp" )) {
963
964
diagnose (Tok, diag::attr_differentiable_expected_label);
964
965
return true ;
@@ -971,18 +972,16 @@ bool Parser::parseDifferentiableAttributeArguments(
971
972
SyntaxParsingContext ContentContext (
972
973
SyntaxContext, SyntaxKind::DifferentiableAttributeArguments);
973
974
974
- // Parse optional differentiation parameters, starting with
975
- // the 'linear' label (optinal ).
975
+ // Parse optional differentiation parameters.
976
+ // Parse 'linear' label (optional ).
976
977
if (Tok.is (tok::identifier) && Tok.getText () == " linear" ) {
977
978
linear = true ;
978
- if (consumeIfTrailingComma ())
979
- return errorAndSkipToEnd ();
980
- consumeToken ();
979
+ consumeIdentifier ();
980
+ // If no trailing comma or 'where' clause, terminate parsing arguments.
981
981
if (Tok.isNot (tok::comma) && Tok.isNot (tok::kw_where))
982
982
return false ;
983
- if (Tok.is (tok::comma)) {
984
- consumeToken ();
985
- }
983
+ if (consumeIfTrailingComma ())
984
+ return errorAndSkipToEnd ();
986
985
} else {
987
986
linear = false ;
988
987
}
0 commit comments