Skip to content

Commit 29562c9

Browse files
author
marcrasi
authored
fix @differentiable(linear) type attr parsing (#27669)
1 parent a9ad846 commit 29562c9

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

lib/Parse/ASTGen.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,15 @@ TypeAttributes ASTGen::generateTypeAttributes(const AttributeListSyntax &syntax,
451451
if (indexTok.getText().getAsInteger(10, index))
452452
continue;
453453
attrs.setOpaqueReturnTypeOf(mangling.str(), index);
454+
// SWIFT_ENABLE_TENSORFLOW
455+
} else if (attr == TAK_differentiable) {
456+
if (arg) {
457+
auto argSyntax = arg->getAs<TokenSyntax>();
458+
attrs.linear = argSyntax->getTokenKind() == tok::identifier &&
459+
argSyntax->getIdentifierText() == "linear";
460+
} else {
461+
attrs.linear = false;
462+
}
454463
}
455464

456465
attrs.setAttr(attr, atLoc);

lib/Parse/ParseDecl.cpp

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2536,49 +2536,64 @@ ParsedSyntaxResult<ParsedAttributeSyntax> Parser::parseTypeAttributeSyntax() {
25362536
}
25372537

25382538
// SWIFT_ENABLE_TENSORFLOW
2539-
case TAK_differentiable: {
2540-
bool linear = false;
2541-
// Check if there is a 'linear' argument.
2542-
if (Tok.is(tok::l_paren) && peekToken().is(tok::identifier)) {
2539+
case TAK_differentiable:
2540+
status |= [&]() -> ParserStatus {
2541+
// Check if there is a 'linear' argument.
2542+
// If next tokens are not `'(' identifier`, break early.
2543+
if (!Tok.is(tok::l_paren) || !peekToken().is(tok::identifier))
2544+
return makeParserSuccess();
2545+
25432546
Parser::BacktrackingScope backtrack(*this);
2544-
consumeToken(tok::l_paren);
2547+
SourceLoc lParenLoc = Tok.getLoc();
2548+
auto lParen = consumeTokenSyntax(tok::l_paren);
25452549

25462550
// Determine if we have '@differentiable(linear) (T) -> U'
25472551
// or '@differentiable (linear) -> U'.
2548-
if (Tok.getText() == "linear" && consumeIf(tok::identifier)) {
2552+
if (Tok.getText() == "linear") {
2553+
auto linearIdentifier = consumeTokenSyntax(tok::identifier);
25492554
if (Tok.is(tok::r_paren) && peekToken().is(tok::l_paren)) {
25502555
// It is being used as an attribute argument, so cancel backtrack
25512556
// as function is linear differentiable.
2552-
linear = true;
25532557
backtrack.cancelBacktrack();
2554-
consumeToken(tok::r_paren);
2558+
builder.useLeftParen(std::move(lParen));
2559+
builder.useArgument(std::move(linearIdentifier));
2560+
SourceLoc rParenLoc;
2561+
auto rParen = parseMatchingTokenSyntax(
2562+
tok::r_paren, rParenLoc, diag::differentiable_attribute_expected_rparen,
2563+
lParenLoc);
2564+
if (!rParen)
2565+
return makeParserError();
2566+
builder.useRightParen(std::move(*rParen));
25552567
} else if (Tok.is(tok::l_paren)) {
25562568
// Handle invalid '@differentiable(linear (T) -> U'
25572569
diagnose(Tok, diag::differentiable_attribute_expected_rparen);
25582570
backtrack.cancelBacktrack();
2559-
status.setIsParseError();
2560-
break;
2571+
builder.useLeftParen(std::move(lParen));
2572+
builder.useArgument(std::move(linearIdentifier));
2573+
return makeParserError();
25612574
}
25622575
} else if (Tok.is(tok::identifier)) {
25632576
// No 'linear' arg or param type, but now checking if the token is being
25642577
// passed in as an invalid argument to '@differentiable'.
25652578
auto possibleArg = Tok.getText();
25662579
auto t = Tok; // get ref to the argument for clearer diagnostics.
2567-
consumeToken(tok::identifier);
2580+
auto argIdentifier = consumeTokenSyntax(tok::identifier);
25682581
// Check if there is an invalid argument getting passed into
25692582
// '@differentiable'.
25702583
if (Tok.is(tok::r_paren) && peekToken().is(tok::l_paren)) {
25712584
// Handling '@differentiable(wrong) (...'.
25722585
diagnose(t, diag::unexpected_argument_differentiable, possibleArg);
2573-
consumeToken(tok::r_paren);
2586+
auto rParen = consumeTokenSyntax(tok::r_paren);
25742587
backtrack.cancelBacktrack();
2575-
status.setIsParseError();
2576-
break;
2588+
builder.useLeftParen(std::move(lParen));
2589+
builder.useArgument(std::move(argIdentifier));
2590+
builder.useRightParen(std::move(rParen));
2591+
return makeParserError();
25772592
}
25782593
}
2579-
}
2594+
return makeParserSuccess();
2595+
}();
25802596
break;
2581-
}
25822597

25832598
case TAK_convention:
25842599
status |= [&]() -> ParserStatus {

0 commit comments

Comments
 (0)