Skip to content

Commit 07ea212

Browse files
authored
[AutoDiff] Be able to parse linear argument in @differentiable attribute (#25228)
- Adds a 'linear` argument to '@differentiable' attribute - 'linear' argument is optional, but when added, needs to be the first argument in the attribute
1 parent ca1e144 commit 07ea212

File tree

4 files changed

+58
-13
lines changed

4 files changed

+58
-13
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,8 @@ ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
14981498
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
14991499
"missing label '%0:' in '@differentiable' attribute", (StringRef))
15001500
ERROR(attr_differentiable_expected_label,none,
1501-
"expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'", ())
1501+
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
1502+
"or 'vjp:'", ())
15021503

15031504
// differentiating
15041505
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,

include/swift/Parse/Parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ class Parser {
954954

955955
/// Parse the arguments inside the @differentiable attribute.
956956
bool parseDifferentiableAttributeArguments(
957-
SmallVectorImpl<ParsedAutoDiffParameter> &params,
957+
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
958958
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
959959
TrailingWhereClause *&whereClause);
960960

lib/Parse/ParseDecl.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -834,11 +834,12 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
834834
Optional<DeclNameWithLoc> jvpSpec;
835835
Optional<DeclNameWithLoc> vjpSpec;
836836
TrailingWhereClause *whereClause = nullptr;
837+
bool linear;
837838

838839
// Parse '('.
839840
if (consumeIf(tok::l_paren, lParenLoc)) {
840841
// Parse @differentiable attribute arguments.
841-
if (parseDifferentiableAttributeArguments(params, jvpSpec, vjpSpec,
842+
if (parseDifferentiableAttributeArguments(linear, params, jvpSpec, vjpSpec,
842843
whereClause))
843844
return makeParserError();
844845
// Parse ')'.
@@ -932,7 +933,7 @@ bool Parser::parseDifferentiationParametersClause(
932933
}
933934

934935
bool Parser::parseDifferentiableAttributeArguments(
935-
SmallVectorImpl<ParsedAutoDiffParameter> &params,
936+
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
936937
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
937938
TrailingWhereClause *&whereClause) {
938939
StringRef AttrName = "differentiable";
@@ -956,8 +957,9 @@ bool Parser::parseDifferentiableAttributeArguments(
956957
diagnose(Tok, diag::unexpected_separator, ",");
957958
return true;
958959
}
959-
// Check that token after comma is a function specifier label.
960-
if (!Tok.is(tok::identifier) || !(Tok.getText() == "jvp" ||
960+
// Check that token after comma is 'wrt:' or a function specifier label.
961+
if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||
962+
Tok.getText() == "jvp" ||
961963
Tok.getText() == "vjp")) {
962964
diagnose(Tok, diag::attr_differentiable_expected_label);
963965
return true;
@@ -970,7 +972,19 @@ bool Parser::parseDifferentiableAttributeArguments(
970972
SyntaxParsingContext ContentContext(
971973
SyntaxContext, SyntaxKind::DifferentiableAttributeArguments);
972974

973-
// Parse optional differentiation parameters, starting with the 'wrt:' label.
975+
// Parse optional differentiation parameters.
976+
// Parse 'linear' label (optional).
977+
linear = false;
978+
if (Tok.is(tok::identifier) && Tok.getText() == "linear") {
979+
linear = true;
980+
consumeToken(tok::identifier);
981+
// If no trailing comma or 'where' clause, terminate parsing arguments.
982+
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
983+
return false;
984+
if (consumeIfTrailingComma())
985+
return errorAndSkipToEnd();
986+
}
987+
974988
// If 'withRespectTo' is used, make the user change it to 'wrt'.
975989
if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") {
976990
SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc());

test/AutoDiff/differentiable_attr_parse.swift

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,39 @@ public func squareRoot() -> Self {
5757
return lhs
5858
}
5959

60+
@differentiable(linear) // okay
61+
func identity(_ x: Float) -> Float {
62+
return x
63+
}
64+
65+
@differentiable(linear, wrt: x) // okay
66+
func slope2(_ x: Float) -> Float {
67+
return 2 * x
68+
}
69+
70+
@differentiable(linear, wrt: x, vjp: const3) // okay
71+
func slope3(_ x: Float) -> Float {
72+
return 3 * x
73+
}
74+
6075
/// Bad
6176

62-
@differentiable(3) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
77+
@differentiable(3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
6378
func bar(_ x: Float, _: Float) -> Float {
6479
return 1 + x
6580
}
6681

67-
@differentiable(foo(_:_:)) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
82+
@differentiable(foo(_:_:)) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
6883
func bar(_ x: Float, _: Float) -> Float {
6984
return 1 + x
7085
}
7186

72-
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
87+
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
7388
func bar(_ x: Float, _: Float) -> Float {
7489
return 1 + x
7590
}
7691

77-
@differentiable(wrt: (x), foo(_:_:)) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
92+
@differentiable(wrt: (x), foo(_:_:)) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
7893
func bar(_ x: Float, _: Float) -> Float {
7994
return 1 + x
8095
}
@@ -84,7 +99,7 @@ func bar(_ x: Float, _: Float) -> Float {
8499
return 1 + x
85100
}
86101

87-
@differentiable(wrt: x, y) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
102+
@differentiable(wrt: x, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
88103
func bar(_ x: Float, _ y: Float) -> Float {
89104
return 1 + x
90105
}
@@ -99,7 +114,7 @@ func bar<T : Numeric>(_ x: T, _: T) -> T {
99114
return 1 + x
100115
}
101116

102-
@differentiable(,) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
117+
@differentiable(,) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
103118
func bar(_ x: Float, _: Float) -> Float {
104119
return 1 + x
105120
}
@@ -113,3 +128,18 @@ func bar(_ x: Float, _: Float) -> Float {
113128
func bar<T : Numeric>(_ x: T, _: T) -> T {
114129
return 1 + x
115130
}
131+
132+
@differentiable(wrt: x, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
133+
func slope4(_ x: Float) -> Float {
134+
return 4 * x
135+
}
136+
137+
@differentiable(wrt: x, linear, vjp: const5) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
138+
func slope5(_ x: Float) -> Float {
139+
return 5 * x
140+
}
141+
142+
@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
143+
func slope5(_ x: Float) -> Float {
144+
return 6 * x
145+
}

0 commit comments

Comments
 (0)