Skip to content

Commit cec88da

Browse files
authored
[AutoDiff] Be able to parse linear argument in @differentiating attribute (#25257)
- Adds a 'linear` argument to '@differentiating' attribute - 'linear' argument is optional, but when added, needs to be the second argument in the attribute
1 parent 8bcee7b commit cec88da

File tree

3 files changed

+136
-6
lines changed

3 files changed

+136
-6
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,8 @@ ERROR(attr_differentiable_expected_label,none,
15041504
// differentiating
15051505
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
15061506
"expected an original function name", ())
1507+
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
1508+
"expected either 'linear' or 'wrt:'", ())
15071509

15081510
// differentiation `wrt` parameters clause
15091511
ERROR(expected_colon_after_label,PointsToFirstBadToken,

lib/Parse/ParseDecl.cpp

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -829,12 +829,12 @@ ParserResult<DifferentiableAttr>
829829
Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
830830
StringRef AttrName = "differentiable";
831831
SourceLoc lParenLoc = loc, rParenLoc = loc;
832-
832+
833+
bool linear;
833834
SmallVector<ParsedAutoDiffParameter, 8> params;
834835
Optional<DeclNameWithLoc> jvpSpec;
835836
Optional<DeclNameWithLoc> vjpSpec;
836837
TrailingWhereClause *whereClause = nullptr;
837-
bool linear;
838838

839839
// Parse '('.
840840
if (consumeIf(tok::l_paren, lParenLoc)) {
@@ -1084,7 +1084,25 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
10841084
StringRef AttrName = "differentiating";
10851085
SourceLoc lParenLoc = loc, rParenLoc = loc;
10861086
DeclNameWithLoc original;
1087+
bool linear = false;
10871088
SmallVector<ParsedAutoDiffParameter, 8> params;
1089+
1090+
// Parse trailing comma, if it exists, and check for errors.
1091+
auto consumeIfTrailingComma = [&]() -> bool {
1092+
if (!consumeIf(tok::comma)) return false;
1093+
// Diagnose trailing comma before ')'.
1094+
if (Tok.is(tok::r_paren)) {
1095+
diagnose(Tok, diag::unexpected_separator, ",");
1096+
return true;
1097+
}
1098+
// Check that token after comma is 'linear' or 'wrt:'.
1099+
if (!Tok.is(tok::identifier) ||
1100+
!(Tok.getText() == "linear" || Tok.getText() == "wrt")) {
1101+
diagnose(Tok, diag::attr_differentiating_expected_label_linear_or_wrt);
1102+
return true;
1103+
}
1104+
return false;
1105+
};
10881106

10891107
// Parse '('.
10901108
if (!consumeIf(tok::l_paren, lParenLoc)) {
@@ -1105,11 +1123,21 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
11051123
/*afterDot*/ false, original.Loc,
11061124
diag::attr_differentiating_expected_original_name,
11071125
/*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true);
1126+
1127+
if (consumeIfTrailingComma())
1128+
return makeParserError();
11081129
}
1109-
1110-
// Parse the optional comma and `wrt` differentiation parameters clause.
1111-
if (consumeIf(tok::comma) &&
1112-
Tok.is(tok::identifier) && Tok.getText() == "wrt" &&
1130+
1131+
// Parse the optional 'linear' differentiation flag.
1132+
if (Tok.is(tok::identifier) && Tok.getText() == "linear") {
1133+
linear = true;
1134+
consumeToken(tok::identifier);
1135+
if (consumeIfTrailingComma())
1136+
return makeParserError();
1137+
}
1138+
1139+
// Parse the optional 'wrt' differentiation parameters clause.
1140+
if (Tok.is(tok::identifier) && Tok.getText() == "wrt" &&
11131141
parseDifferentiationParametersClause(params, AttrName))
11141142
return makeParserError();
11151143
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: %target-swift-frontend -parse -verify %s
2+
3+
/// Good
4+
5+
@differentiating(sin) // ok
6+
func jvpSin(x: @nondiff Float)
7+
-> (value: Float, differential: (Float)-> (Float)) {
8+
return (x, { $0 })
9+
}
10+
11+
@differentiating(sin, wrt: x) // ok
12+
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
13+
return (x, { $0 })
14+
}
15+
16+
@differentiating(add, wrt: (x, y)) // ok
17+
func vjpAdd(x: Float, y: Float)
18+
-> (value: Float, pullback: (Float) -> (Float, Float)) {
19+
return (x + y, { ($0, $0) })
20+
}
21+
22+
extension AdditiveArithmetic where Self : Differentiable {
23+
@differentiating(+) // ok
24+
static func vjpPlus(x: Self, y: Self) -> (value: Self,
25+
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
26+
return (x + y, { v in (v, v) })
27+
}
28+
}
29+
30+
@differentiating(linear) // ok
31+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
32+
return (x, { $0 })
33+
}
34+
35+
@differentiating(linear, linear) // ok
36+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
37+
return (x, { $0 })
38+
}
39+
40+
@differentiating(foo, linear) // ok
41+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
42+
return (x, { $0 })
43+
}
44+
45+
@differentiating(foo, linear, wrt: x) // ok
46+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
47+
return (x, { $0 })
48+
}
49+
50+
/// Bad
51+
52+
// expected-error @+3 {{expected an original function name}}
53+
// expected-error @+2 {{expected ')' in 'differentiating' attribute}}
54+
// expected-error @+1 {{expected declaration}}
55+
@differentiating(3)
56+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
57+
return (x, { $0 })
58+
}
59+
60+
// expected-error @+2 {{expected either 'linear' or 'wrt:'}}
61+
// expected-error @+1 {{expected declaration}}
62+
@differentiating(linear, foo)
63+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
64+
return (x, { $0 })
65+
}
66+
67+
// expected-error @+2 {{expected ')' in 'differentiating' attribute}}
68+
// expected-error @+1 {{expected declaration}}
69+
@differentiating(foo, wrt: x, linear)
70+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
71+
return (x, { $0 })
72+
}
73+
74+
// expected-error @+2 {{unexpected ',' separator}}
75+
// expected-error @+1 {{expected declaration}}
76+
@differentiating(foo,)
77+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
78+
return (x, { $0 })
79+
}
80+
81+
// expected-error @+2 {{expected ')' in 'differentiating' attribute}}
82+
// expected-error @+1 {{expected declaration}}
83+
@differentiating(foo, wrt: x,)
84+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
85+
return (x, { $0 })
86+
}
87+
88+
// expected-error @+2 {{expected either 'linear' or 'wrt:'}}
89+
// expected-error @+1 {{expected declaration}}
90+
@differentiating(linear, foo,)
91+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
92+
return (x, { $0 })
93+
}
94+
95+
// expected-error @+2 {{unexpected ',' separator}}
96+
// expected-error @+1 {{expected declaration}}
97+
@differentiating(linear,)
98+
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
99+
return (x, { $0 })
100+
}

0 commit comments

Comments
 (0)