Skip to content

Commit dddcfb8

Browse files
authored
Merge pull request #29984 from dan-zheng/TF-1168
2 parents cd4f3b0 + cef43e8 commit dddcfb8

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

lib/Parse/ParseDecl.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,8 +1176,16 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
11761176
SmallVector<ParsedAutoDiffParameter, 8> parameters;
11771177

11781178
// Parse trailing comma, if it exists, and check for errors.
1179-
auto consumeIfTrailingComma = [&]() -> bool {
1180-
if (!consumeIf(tok::comma)) return false;
1179+
auto consumeIfTrailingComma = [&](bool requireComma = false) -> bool {
1180+
if (!consumeIf(tok::comma)) {
1181+
// If comma is required but does not exist and ')' has not been reached,
1182+
// diagnose missing comma.
1183+
if (requireComma && !Tok.is(tok::r_paren)) {
1184+
diagnose(getEndOfPreviousLoc(), diag::expected_separator, ",");
1185+
return true;
1186+
}
1187+
return false;
1188+
}
11811189
// Diagnose trailing comma before ')'.
11821190
if (Tok.is(tok::r_paren)) {
11831191
diagnose(Tok, diag::unexpected_separator, ",");
@@ -1211,7 +1219,7 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
12111219
baseType, original))
12121220
return makeParserError();
12131221
}
1214-
if (consumeIfTrailingComma())
1222+
if (consumeIfTrailingComma(/*requireComma*/ true))
12151223
return makeParserError();
12161224
// Parse the optional 'wrt' differentiability parameters clause.
12171225
if (isIdentifier(Tok, "wrt") &&
@@ -1250,8 +1258,16 @@ ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
12501258
SmallVector<ParsedAutoDiffParameter, 8> parameters;
12511259

12521260
// Parse trailing comma, if it exists, and check for errors.
1253-
auto consumeIfTrailingComma = [&]() -> bool {
1254-
if (!consumeIf(tok::comma)) return false;
1261+
auto consumeIfTrailingComma = [&](bool requireComma = false) -> bool {
1262+
if (!consumeIf(tok::comma)) {
1263+
// If comma is required but does not exist and ')' has not been reached,
1264+
// diagnose missing comma.
1265+
if (requireComma && !Tok.is(tok::r_paren)) {
1266+
diagnose(Tok, diag::expected_separator, ",");
1267+
return true;
1268+
}
1269+
return false;
1270+
}
12551271
// Diagnose trailing comma before ')'.
12561272
if (Tok.is(tok::r_paren)) {
12571273
diagnose(Tok, diag::unexpected_separator, ",");
@@ -1286,7 +1302,7 @@ ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
12861302
baseType, original))
12871303
return makeParserError();
12881304
}
1289-
if (consumeIfTrailingComma())
1305+
if (consumeIfTrailingComma(/*requireComma*/ true))
12901306
return makeParserError();
12911307
// Parse the optional 'wrt' linearity parameters clause.
12921308
if (Tok.is(tok::identifier) && Tok.getText() == "wrt" &&

test/AutoDiff/Parse/derivative_attr_parse.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,15 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
8585
return (x, { $0 })
8686
}
8787

88-
func localDerivativeRegistration() {
88+
// TF-1168: missing comma before `wrt:`.
89+
// expected-error @+2 {{expected ',' separator}}
90+
// expected-error @+1 {{expected declaration}}
91+
@derivative(of: foo wrt: x)
92+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
93+
return (x, { $0 })
94+
}
95+
96+
func testLocalDerivativeRegistration() {
8997
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
9098
@derivative(of: sin)
9199
func dsin()

test/AutoDiff/Parse/transpose_attr_parse.swift

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,28 @@ func transpose(v: Float) -> Float
7878
@transpose(of: foo, wrt: (0, v))
7979
func transpose(v: Float) -> Float
8080

81-
// expected-error @+2 {{expected ')' in 'transpose' attribute}}
81+
// NOTE: The "expected ',' separator" diagnostic is not ideal.
82+
// Ideally, the diagnostic should point out that that `Swift.Float.+(_:_)` is
83+
// not a valid declaration name (missing colon after second argument label).
84+
// expected-error @+2 {{expected ',' separator}}
8285
// expected-error @+1 {{expected declaration}}
8386
@transpose(of: Swift.Float.+(_:_))
8487
func transpose(v: Float) -> Float
8588

86-
// expected-error @+2 {{expected ')' in 'transpose' attribute}}
89+
// NOTE: The "expected ',' separator" diagnostic is not ideal.
90+
// Ideally, the diagnostic should point out that that `Swift.Float.+.a` is
91+
// not a valid declaration name.
92+
// expected-error @+2 {{expected ',' separator}}
8793
// expected-error @+1 {{expected declaration}}
8894
@transpose(of: Swift.Float.+.a)
8995
func transpose(v: Float) -> Float
9096

97+
// TF-1168: missing comma before `wrt:`.
98+
// expected-error @+2 {{expected ',' separator}}
99+
// expected-error @+1 {{expected declaration}}
100+
@transpose(of: foo wrt: x)
101+
func transpose(v: Float) -> Float
102+
91103
func testLocalTransposeRegistration() {
92104
// Transpose registration can only be non-local.
93105
// expected-error @+1 {{attribute '@transpose' can only be used in a non-local scope}}

0 commit comments

Comments
 (0)