Skip to content

Commit 4e3d511

Browse files
authored
---
yaml --- r: 294859 b: refs/heads/tensorflow c: 07ea212 h: refs/heads/master i: 294857: ee8655e 294855: 5d63e24
1 parent 0f878b8 commit 4e3d511

File tree

5 files changed

+59
-14
lines changed

5 files changed

+59
-14
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
816816
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
817817
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
818818
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
819-
refs/heads/tensorflow: ca1e144b3cba43f3cfd465aeec411cb86b15a9d4
819+
refs/heads/tensorflow: 07ea212467d7b5643e85fdffd0249194ae37c03c
820820
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
821821
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,8 @@ ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
14981498
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
14991499
"missing label '%0:' in '@differentiable' attribute", (StringRef))
15001500
ERROR(attr_differentiable_expected_label,none,
1501-
"expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'", ())
1501+
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
1502+
"or 'vjp:'", ())
15021503

15031504
// differentiating
15041505
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,

branches/tensorflow/include/swift/Parse/Parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ class Parser {
954954

955955
/// Parse the arguments inside the @differentiable attribute.
956956
bool parseDifferentiableAttributeArguments(
957-
SmallVectorImpl<ParsedAutoDiffParameter> &params,
957+
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
958958
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
959959
TrailingWhereClause *&whereClause);
960960

branches/tensorflow/lib/Parse/ParseDecl.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -834,11 +834,12 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
834834
Optional<DeclNameWithLoc> jvpSpec;
835835
Optional<DeclNameWithLoc> vjpSpec;
836836
TrailingWhereClause *whereClause = nullptr;
837+
bool linear;
837838

838839
// Parse '('.
839840
if (consumeIf(tok::l_paren, lParenLoc)) {
840841
// Parse @differentiable attribute arguments.
841-
if (parseDifferentiableAttributeArguments(params, jvpSpec, vjpSpec,
842+
if (parseDifferentiableAttributeArguments(linear, params, jvpSpec, vjpSpec,
842843
whereClause))
843844
return makeParserError();
844845
// Parse ')'.
@@ -932,7 +933,7 @@ bool Parser::parseDifferentiationParametersClause(
932933
}
933934

934935
bool Parser::parseDifferentiableAttributeArguments(
935-
SmallVectorImpl<ParsedAutoDiffParameter> &params,
936+
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
936937
Optional<DeclNameWithLoc> &jvpSpec, Optional<DeclNameWithLoc> &vjpSpec,
937938
TrailingWhereClause *&whereClause) {
938939
StringRef AttrName = "differentiable";
@@ -956,8 +957,9 @@ bool Parser::parseDifferentiableAttributeArguments(
956957
diagnose(Tok, diag::unexpected_separator, ",");
957958
return true;
958959
}
959-
// Check that token after comma is a function specifier label.
960-
if (!Tok.is(tok::identifier) || !(Tok.getText() == "jvp" ||
960+
// Check that token after comma is 'wrt:' or a function specifier label.
961+
if (!Tok.is(tok::identifier) || !(Tok.getText() == "wrt" ||
962+
Tok.getText() == "jvp" ||
961963
Tok.getText() == "vjp")) {
962964
diagnose(Tok, diag::attr_differentiable_expected_label);
963965
return true;
@@ -970,7 +972,19 @@ bool Parser::parseDifferentiableAttributeArguments(
970972
SyntaxParsingContext ContentContext(
971973
SyntaxContext, SyntaxKind::DifferentiableAttributeArguments);
972974

973-
// Parse optional differentiation parameters, starting with the 'wrt:' label.
975+
// Parse optional differentiation parameters.
976+
// Parse 'linear' label (optional).
977+
linear = false;
978+
if (Tok.is(tok::identifier) && Tok.getText() == "linear") {
979+
linear = true;
980+
consumeToken(tok::identifier);
981+
// If no trailing comma or 'where' clause, terminate parsing arguments.
982+
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
983+
return false;
984+
if (consumeIfTrailingComma())
985+
return errorAndSkipToEnd();
986+
}
987+
974988
// If 'withRespectTo' is used, make the user change it to 'wrt'.
975989
if (Tok.is(tok::identifier) && Tok.getText() == "withRespectTo") {
976990
SourceRange withRespectToRange(Tok.getLoc(), peekToken().getLoc());

branches/tensorflow/test/AutoDiff/differentiable_attr_parse.swift

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,39 @@ public func squareRoot() -> Self {
5757
return lhs
5858
}
5959

60+
@differentiable(linear) // okay
61+
func identity(_ x: Float) -> Float {
62+
return x
63+
}
64+
65+
@differentiable(linear, wrt: x) // okay
66+
func slope2(_ x: Float) -> Float {
67+
return 2 * x
68+
}
69+
70+
@differentiable(linear, wrt: x, vjp: const3) // okay
71+
func slope3(_ x: Float) -> Float {
72+
return 3 * x
73+
}
74+
6075
/// Bad
6176

62-
@differentiable(3) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
77+
@differentiable(3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
6378
func bar(_ x: Float, _: Float) -> Float {
6479
return 1 + x
6580
}
6681

67-
@differentiable(foo(_:_:)) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
82+
@differentiable(foo(_:_:)) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
6883
func bar(_ x: Float, _: Float) -> Float {
6984
return 1 + x
7085
}
7186

72-
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
87+
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
7388
func bar(_ x: Float, _: Float) -> Float {
7489
return 1 + x
7590
}
7691

77-
@differentiable(wrt: (x), foo(_:_:)) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
92+
@differentiable(wrt: (x), foo(_:_:)) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
7893
func bar(_ x: Float, _: Float) -> Float {
7994
return 1 + x
8095
}
@@ -84,7 +99,7 @@ func bar(_ x: Float, _: Float) -> Float {
8499
return 1 + x
85100
}
86101

87-
@differentiable(wrt: x, y) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
102+
@differentiable(wrt: x, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
88103
func bar(_ x: Float, _ y: Float) -> Float {
89104
return 1 + x
90105
}
@@ -99,7 +114,7 @@ func bar<T : Numeric>(_ x: T, _: T) -> T {
99114
return 1 + x
100115
}
101116

102-
@differentiable(,) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
117+
@differentiable(,) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
103118
func bar(_ x: Float, _: Float) -> Float {
104119
return 1 + x
105120
}
@@ -113,3 +128,18 @@ func bar(_ x: Float, _: Float) -> Float {
113128
func bar<T : Numeric>(_ x: T, _: T) -> T {
114129
return 1 + x
115130
}
131+
132+
@differentiable(wrt: x, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
133+
func slope4(_ x: Float) -> Float {
134+
return 4 * x
135+
}
136+
137+
@differentiable(wrt: x, linear, vjp: const5) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
138+
func slope5(_ x: Float) -> Float {
139+
return 5 * x
140+
}
141+
142+
@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
143+
func slope5(_ x: Float) -> Float {
144+
return 6 * x
145+
}

0 commit comments

Comments
 (0)