@@ -2536,49 +2536,64 @@ ParsedSyntaxResult<ParsedAttributeSyntax> Parser::parseTypeAttributeSyntax() {
2536
2536
}
2537
2537
2538
2538
// SWIFT_ENABLE_TENSORFLOW
2539
- case TAK_differentiable: {
2540
- bool linear = false ;
2541
- // Check if there is a 'linear' argument.
2542
- if (Tok.is (tok::l_paren) && peekToken ().is (tok::identifier)) {
2539
+ case TAK_differentiable:
2540
+ status |= [&]() -> ParserStatus {
2541
+ // Check if there is a 'linear' argument.
2542
+ // If next tokens are not `'(' identifier`, break early.
2543
+ if (!Tok.is (tok::l_paren) || !peekToken ().is (tok::identifier))
2544
+ return makeParserSuccess ();
2545
+
2543
2546
Parser::BacktrackingScope backtrack (*this );
2544
- consumeToken (tok::l_paren);
2547
+ SourceLoc lParenLoc = Tok.getLoc ();
2548
+ auto lParen = consumeTokenSyntax (tok::l_paren);
2545
2549
2546
2550
// Determine if we have '@differentiable(linear) (T) -> U'
2547
2551
// or '@differentiable (linear) -> U'.
2548
- if (Tok.getText () == " linear" && consumeIf (tok::identifier)) {
2552
+ if (Tok.getText () == " linear" ) {
2553
+ auto linearIdentifier = consumeTokenSyntax (tok::identifier);
2549
2554
if (Tok.is (tok::r_paren) && peekToken ().is (tok::l_paren)) {
2550
2555
// It is being used as an attribute argument, so cancel backtrack
2551
2556
// as function is linear differentiable.
2552
- linear = true ;
2553
2557
backtrack.cancelBacktrack ();
2554
- consumeToken (tok::r_paren);
2558
+ builder.useLeftParen (std::move (lParen));
2559
+ builder.useArgument (std::move (linearIdentifier));
2560
+ SourceLoc rParenLoc;
2561
+ auto rParen = parseMatchingTokenSyntax (
2562
+ tok::r_paren, rParenLoc, diag::differentiable_attribute_expected_rparen,
2563
+ lParenLoc);
2564
+ if (!rParen)
2565
+ return makeParserError ();
2566
+ builder.useRightParen (std::move (*rParen));
2555
2567
} else if (Tok.is (tok::l_paren)) {
2556
2568
// Handle invalid '@differentiable(linear (T) -> U'
2557
2569
diagnose (Tok, diag::differentiable_attribute_expected_rparen);
2558
2570
backtrack.cancelBacktrack ();
2559
- status.setIsParseError ();
2560
- break ;
2571
+ builder.useLeftParen (std::move (lParen));
2572
+ builder.useArgument (std::move (linearIdentifier));
2573
+ return makeParserError ();
2561
2574
}
2562
2575
} else if (Tok.is (tok::identifier)) {
2563
2576
// No 'linear' arg or param type, but now checking if the token is being
2564
2577
// passed in as an invalid argument to '@differentiable'.
2565
2578
auto possibleArg = Tok.getText ();
2566
2579
auto t = Tok; // get ref to the argument for clearer diagnostics.
2567
- consumeToken (tok::identifier);
2580
+ auto argIdentifier = consumeTokenSyntax (tok::identifier);
2568
2581
// Check if there is an invalid argument getting passed into
2569
2582
// '@differentiable'.
2570
2583
if (Tok.is (tok::r_paren) && peekToken ().is (tok::l_paren)) {
2571
2584
// Handling '@differentiable(wrong) (...'.
2572
2585
diagnose (t, diag::unexpected_argument_differentiable, possibleArg);
2573
- consumeToken (tok::r_paren);
2586
+ auto rParen = consumeTokenSyntax (tok::r_paren);
2574
2587
backtrack.cancelBacktrack ();
2575
- status.setIsParseError ();
2576
- break ;
2588
+ builder.useLeftParen (std::move (lParen));
2589
+ builder.useArgument (std::move (argIdentifier));
2590
+ builder.useRightParen (std::move (rParen));
2591
+ return makeParserError ();
2577
2592
}
2578
2593
}
2579
- }
2594
+ return makeParserSuccess ();
2595
+ }();
2580
2596
break ;
2581
- }
2582
2597
2583
2598
case TAK_convention:
2584
2599
status |= [&]() -> ParserStatus {
0 commit comments