Skip to content

Commit 2df5fc4

Browse files
authored
[AutoDiff] Deprecate @differentiable(jvp:vjp:) arguments. (#28932)
Deprecate `@differentiable` attribute `jvp:` and `vjp:` arguments for derivative registration. `@derivative` attribute is the canonical way to register derivatives. Update tests. TF-1001 tracks removing `@differentiable` attribute `jvp:` and `vjp:` arguments. TF-1085 tracks removing `@differentiable(jvp:vjp:)` usages in the stdlib.
1 parent d10f51a commit 2df5fc4

12 files changed

+174
-82
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,10 +1573,12 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
15731573
"expected a member name as second parameter in '_implements' attribute", ())
15741574

15751575
// differentiable
1576+
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
15761577
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
15771578
"expected a %0 function name", (StringRef))
15781579
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
15791580
"expected a list of parameters to differentiate with respect to", ())
1581+
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
15801582
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
15811583
"use 'wrt:' to specify parameters to differentiate with respect to", ())
15821584
ERROR(attr_differentiable_expected_label,none,
@@ -1586,6 +1588,11 @@ ERROR(differentiable_attribute_expected_rparen,none,
15861588
"expected ')' in '@differentiable' attribute", ())
15871589
ERROR(unexpected_argument_differentiable,none,
15881590
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
1591+
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
1592+
WARNING(differentiable_attr_jvp_vjp_deprecated_warning,none,
1593+
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
1594+
"deprecated; use '@derivative' attribute for derivative registration "
1595+
"instead", ())
15891596

15901597
// derivative
15911598
WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,

lib/Parse/ParseDecl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,13 @@ bool Parser::parseDifferentiableAttributeArguments(
10211021
{ label });
10221022
result.Name = parseDeclNameRef(result.Loc, funcDiag,
10231023
DeclNameFlag::AllowZeroArgCompoundNames | DeclNameFlag::AllowOperators);
1024+
// Emit warning for deprecated `jvp:` and `vjp:` arguments.
1025+
// TODO(TF-1001): Remove deprecated `jvp:` and `vjp:` arguments.
1026+
if (result.Loc.isValid()) {
1027+
diagnose(result.Loc.getStartLoc(),
1028+
diag::differentiable_attr_jvp_vjp_deprecated_warning)
1029+
.highlight(result.Loc.getSourceRange());
1030+
}
10241031
// If no trailing comma or 'where' clause, terminate parsing arguments.
10251032
if (Tok.isNot(tok::comma, tok::kw_where))
10261033
terminateParsingArgs = true;
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
struct TF_619: Differentiable {
22
var p: Float = 1
33

4-
@differentiable(vjp: vjpFoo)
4+
@differentiable
55
func foo(_ x: Float) -> Float {
66
return p * x
77
}
88

9-
func vjpFoo(_ x: Float) -> (Float, (Float) -> (TangentVector, Float)) {
9+
@derivative(of: foo)
10+
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) {
1011
return (x, { v in (TangentVector(p: v * x), v * self.p) })
1112
}
1213
}

test/AutoDiff/Parse/differentiable_attr_parse.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@ struct Foo {
77
var x: Float
88
}
99

10+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
1011
@differentiable(vjp: foo(_:_:)) // okay
1112
func bar(_ x: Float, _: Float) -> Float {
1213
return 1 + x
1314
}
1415

16+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
1517
@differentiable(vjp: foo(_:_:) where T : FloatingPoint) // okay
1618
func bar<T : Numeric>(_ x: T, _: T) -> T {
1719
return 1 + x
1820
}
1921

22+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
2023
@differentiable(wrt: (self, x, y), vjp: foo(_:_:)) // okay
2124
func bar(_ x: Float, _ y: Float) -> Float {
2225
return 1 + x
2326
}
2427

28+
// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
2529
@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:)) // okay
2630
func bar(_ x: Float, _ y: Float) -> Float {
2731
return 1 + x
@@ -55,6 +59,7 @@ func playWellWithOtherAttrs(_ x: Float, _: Float) -> Float {
5559
}
5660

5761
@_transparent
62+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
5863
@differentiable(wrt: (self), vjp: _vjpSquareRoot) // okay
5964
public func squareRoot() -> Self {
6065
var lhs = self
@@ -109,6 +114,7 @@ func bar(_ x: Float, _: Float) -> Float {
109114
return 1 + x
110115
}
111116

117+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
112118
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
113119
func bar(_ x: Float, _: Float) -> Float {
114120
return 1 + x
@@ -139,11 +145,13 @@ func two(x: Float, y: Float) -> Float {
139145
return x + y
140146
}
141147

148+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
142149
@differentiable(vjp: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}}
143150
func bar(_ x: Float, _: Float) -> Float {
144151
return 1 + x
145152
}
146153

154+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
147155
@differentiable(vjp: foo(_:_:) where T) // expected-error {{expected ':' or '==' to indicate a conformance or same-type requirement}}
148156
func bar<T : Numeric>(_ x: T, _: T) -> T {
149157
return 1 + x
@@ -154,11 +162,13 @@ func bar(_ x: Float, _: Float) -> Float {
154162
return 1 + x
155163
}
156164

165+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
157166
@differentiable(vjp: foo(_:_:),) // expected-error {{unexpected ',' separator}}
158167
func bar(_ x: Float, _: Float) -> Float {
159168
return 1 + x
160169
}
161170

171+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
162172
@differentiable(vjp: foo(_:_:), where T) // expected-error {{unexpected ',' separator}}
163173
func bar<T : Numeric>(_ x: T, _: T) -> T {
164174
return 1 + x
@@ -174,6 +184,7 @@ func slope5(_ x: Float) -> Float {
174184
return 5 * x
175185
}
176186

187+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
177188
@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
178189
func slope5(_ x: Float) -> Float {
179190
return 6 * x

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,13 @@ _ = gradient(at: 1.0, in: generic)
1414

1515
// Test unmet generic requirements.
1616

17-
@differentiable(
18-
vjp: vjpWeirdExtraRequirements
19-
where T : Differentiable & CaseIterable, T.AllCases : ExpressibleByStringLiteral
20-
)
2117
func weird<T>(_ x: T) -> T {
2218
return x
2319
}
24-
func vjpWeirdExtraRequirements<
25-
T : Differentiable & CaseIterable
26-
>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector)
27-
where T.AllCases : ExpressibleByStringLiteral
28-
{
20+
@derivative(of: weird)
21+
func vjpWeirdExtraRequirements<T : Differentiable & CaseIterable>(_ x: T) -> (
22+
value: T, pullback: (T.TangentVector) -> T.TangentVector
23+
) where T.AllCases : ExpressibleByStringLiteral {
2924
return (x, { $0 })
3025
}
3126
func weirdWrapper<T : Differentiable>(_ x: T) -> T {

test/AutoDiff/derivative_attr_type_checking.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ extension Class {
389389
// both attributes register the same derivatives. This was previously valid
390390
// but is now rejected.
391391

392+
// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated; use '@derivative' attribute for derivative registration instead}}
392393
@differentiable(jvp: jvpConsistent, vjp: vjpConsistent)
393394
func consistentSpecifiedDerivatives(_ x: Float) -> Float {
394395
return x

test/AutoDiff/differentiable_attr_access_control.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@
44
// its JVP/VJP must also be exported.
55

66
// Ok: all public.
7-
@differentiable(vjp: dfoo1)
8-
public func foo1(_ x: Float) -> Float { return 1 }
9-
public func dfoo1(x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
7+
public func foo1(_ x: Float) -> Float { x }
8+
@derivative(of: foo1)
9+
public func dfoo1(x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }
1010

1111
// Ok: all internal.
1212
struct CheckpointsFoo {}
13-
@differentiable(vjp: dfoo2)
14-
func foo2(_ x: Float) -> Float { return 1 }
15-
func dfoo2(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
13+
func foo2(_ x: Float) -> Float { x }
14+
@derivative(of: foo2)
15+
func dfoo2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }
1616

1717
// Ok: all private.
18-
@differentiable(vjp: dfoo3)
19-
private func foo3(_ x: Float) -> Float { return 1 }
20-
private func dfoo3(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
18+
private func foo3(_ x: Float) -> Float { x }
19+
@derivative(of: foo3)
20+
private func dfoo3(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }
2121

2222
// Error: vjp not exported.
23-
@differentiable(vjp: dbar1) // expected-error {{derivative function 'dbar1' is required to either be public or '@usableFromInline' because the original function 'bar1' is public or '@usableFromInline'}}
24-
public func bar1(_ x: Float) -> Float { return 1 }
25-
private func dbar1(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
23+
public func bar1(_ x: Float) -> Float { x }
24+
@derivative(of: bar1)
25+
private func dbar1(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }

test/AutoDiff/differentiable_attr_parse.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@ struct Foo {
77
var x: Float
88
}
99

10+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
1011
@differentiable(vjp: foo(_:_:)) // okay
1112
func bar(_ x: Float, _: Float) -> Float {
1213
return 1 + x
1314
}
1415

16+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
1517
@differentiable(vjp: foo(_:_:) where T : FloatingPoint) // okay
1618
func bar<T : Numeric>(_ x: T, _: T) -> T {
1719
return 1 + x
1820
}
1921

22+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
2023
@differentiable(wrt: (self, x, y), vjp: foo(_:_:)) // okay
2124
func bar(_ x: Float, _ y: Float) -> Float {
2225
return 1 + x
2326
}
2427

28+
// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
2529
@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:)) // okay
2630
func bar(_ x: Float, _ y: Float) -> Float {
2731
return 1 + x
@@ -55,6 +59,7 @@ func playWellWithOtherAttrs(_ x: Float, _: Float) -> Float {
5559
}
5660

5761
@_transparent
62+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
5863
@differentiable(wrt: (self), vjp: _vjpSquareRoot) // okay
5964
public func squareRoot() -> Self {
6065
var lhs = self
@@ -109,6 +114,7 @@ func bar(_ x: Float, _: Float) -> Float {
109114
return 1 + x
110115
}
111116

117+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
112118
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
113119
func bar(_ x: Float, _: Float) -> Float {
114120
return 1 + x
@@ -139,11 +145,13 @@ func two(x: Float, y: Float) -> Float {
139145
return x + y
140146
}
141147

148+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
142149
@differentiable(vjp: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}}
143150
func bar(_ x: Float, _: Float) -> Float {
144151
return 1 + x
145152
}
146153

154+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
147155
@differentiable(vjp: foo(_:_:) where T) // expected-error {{expected ':' or '==' to indicate a conformance or same-type requirement}}
148156
func bar<T : Numeric>(_ x: T, _: T) -> T {
149157
return 1 + x
@@ -154,11 +162,13 @@ func bar(_ x: Float, _: Float) -> Float {
154162
return 1 + x
155163
}
156164

165+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
157166
@differentiable(vjp: foo(_:_:),) // expected-error {{unexpected ',' separator}}
158167
func bar(_ x: Float, _: Float) -> Float {
159168
return 1 + x
160169
}
161170

171+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
162172
@differentiable(vjp: foo(_:_:), where T) // expected-error {{unexpected ',' separator}}
163173
func bar<T : Numeric>(_ x: T, _: T) -> T {
164174
return 1 + x
@@ -174,6 +184,7 @@ func slope5(_ x: Float) -> Float {
174184
return 5 * x
175185
}
176186

187+
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
177188
@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
178189
func slope5(_ x: Float) -> Float {
179190
return 6 * x
@@ -185,6 +196,7 @@ func localDifferentiableDeclaration() {
185196
func foo1(_ x: Float) -> Float
186197

187198
// Not okay. Derivative registration can only be non-local.
199+
// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
188200
// expected-error @+1 {{attribute '@differentiable(jvp:vjp:)' can only be used in a non-local scope}}
189201
@differentiable(vjp: dfoo2)
190202
func foo2(_ x: Float) -> Float

0 commit comments

Comments
 (0)