Skip to content

[AutoDiff] Deprecate @differentiable(jvp:vjp:) arguments. #28932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1573,10 +1573,12 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
"expected a member name as second parameter in '_implements' attribute", ())

// differentiable
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
"expected a %0 function name", (StringRef))
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to", ())
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
"use 'wrt:' to specify parameters to differentiate with respect to", ())
ERROR(attr_differentiable_expected_label,none,
Expand All @@ -1586,6 +1588,11 @@ ERROR(differentiable_attribute_expected_rparen,none,
"expected ')' in '@differentiable' attribute", ())
ERROR(unexpected_argument_differentiable,none,
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
WARNING(differentiable_attr_jvp_vjp_deprecated_warning,none,
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
"deprecated; use '@derivative' attribute for derivative registration "
"instead", ())

// derivative
WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
Expand Down
7 changes: 7 additions & 0 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,13 @@ bool Parser::parseDifferentiableAttributeArguments(
{ label });
result.Name = parseDeclNameRef(result.Loc, funcDiag,
DeclNameFlag::AllowZeroArgCompoundNames | DeclNameFlag::AllowOperators);
// Emit warning for deprecated `jvp:` and `vjp:` arguments.
// TODO(TF-1001): Remove deprecated `jvp:` and `vjp:` arguments.
if (result.Loc.isValid()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: result.Loc may not be valid if parseDeclNameRef produces an error.
This can happen when jvp:vjp: specifies a non-func declaration, e.g. an init or subscript.

diagnose(result.Loc.getStartLoc(),
diag::differentiable_attr_jvp_vjp_deprecated_warning)
.highlight(result.Loc.getSourceRange());
}
// If no trailing comma or 'where' clause, terminate parsing arguments.
if (Tok.isNot(tok::comma, tok::kw_where))
terminateParsingArgs = true;
Expand Down
5 changes: 3 additions & 2 deletions test/AutoDiff/Inputs/silgen_thunking_other_module.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
struct TF_619: Differentiable {
var p: Float = 1

@differentiable(vjp: vjpFoo)
@differentiable
func foo(_ x: Float) -> Float {
return p * x
}

func vjpFoo(_ x: Float) -> (Float, (Float) -> (TangentVector, Float)) {
@derivative(of: foo)
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) {
return (x, { v in (TangentVector(p: v * x), v * self.p) })
}
}
11 changes: 11 additions & 0 deletions test/AutoDiff/Parse/differentiable_attr_parse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@ struct Foo {
var x: Float
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:)) // okay
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:) where T : FloatingPoint) // okay
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: (self, x, y), vjp: foo(_:_:)) // okay
func bar(_ x: Float, _ y: Float) -> Float {
return 1 + x
}

// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:)) // okay
func bar(_ x: Float, _ y: Float) -> Float {
return 1 + x
Expand Down Expand Up @@ -55,6 +59,7 @@ func playWellWithOtherAttrs(_ x: Float, _: Float) -> Float {
}

@_transparent
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: (self), vjp: _vjpSquareRoot) // okay
public func squareRoot() -> Self {
var lhs = self
Expand Down Expand Up @@ -109,6 +114,7 @@ func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
Expand Down Expand Up @@ -139,11 +145,13 @@ func two(x: Float, y: Float) -> Float {
return x + y
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:) where T) // expected-error {{expected ':' or '==' to indicate a conformance or same-type requirement}}
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
Expand All @@ -154,11 +162,13 @@ func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:),) // expected-error {{unexpected ',' separator}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:), where T) // expected-error {{unexpected ',' separator}}
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
Expand All @@ -174,6 +184,7 @@ func slope5(_ x: Float) -> Float {
return 5 * x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
func slope5(_ x: Float) -> Float {
return 6 * x
Expand Down
13 changes: 4 additions & 9 deletions test/AutoDiff/autodiff_indirect_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@ _ = gradient(at: 1.0, in: generic)

// Test unmet generic requirements.

@differentiable(
vjp: vjpWeirdExtraRequirements
where T : Differentiable & CaseIterable, T.AllCases : ExpressibleByStringLiteral
)
func weird<T>(_ x: T) -> T {
return x
}
func vjpWeirdExtraRequirements<
T : Differentiable & CaseIterable
>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector)
where T.AllCases : ExpressibleByStringLiteral
{
@derivative(of: weird)
func vjpWeirdExtraRequirements<T : Differentiable & CaseIterable>(_ x: T) -> (
value: T, pullback: (T.TangentVector) -> T.TangentVector
) where T.AllCases : ExpressibleByStringLiteral {
return (x, { $0 })
}
func weirdWrapper<T : Differentiable>(_ x: T) -> T {
Expand Down
1 change: 1 addition & 0 deletions test/AutoDiff/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ extension Class {
// both attributes register the same derivatives. This was previously valid
// but is now rejected.

// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated; use '@derivative' attribute for derivative registration instead}}
@differentiable(jvp: jvpConsistent, vjp: vjpConsistent)
func consistentSpecifiedDerivatives(_ x: Float) -> Float {
return x
Expand Down
24 changes: 12 additions & 12 deletions test/AutoDiff/differentiable_attr_access_control.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@
// its JVP/VJP must also be exported.

// Ok: all public.
@differentiable(vjp: dfoo1)
public func foo1(_ x: Float) -> Float { return 1 }
public func dfoo1(x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
public func foo1(_ x: Float) -> Float { x }
@derivative(of: foo1)
public func dfoo1(x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }

// Ok: all internal.
struct CheckpointsFoo {}
@differentiable(vjp: dfoo2)
func foo2(_ x: Float) -> Float { return 1 }
func dfoo2(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
func foo2(_ x: Float) -> Float { x }
@derivative(of: foo2)
func dfoo2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }

// Ok: all private.
@differentiable(vjp: dfoo3)
private func foo3(_ x: Float) -> Float { return 1 }
private func dfoo3(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
private func foo3(_ x: Float) -> Float { x }
@derivative(of: foo3)
private func dfoo3(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }

// Error: vjp not exported.
@differentiable(vjp: dbar1) // expected-error {{derivative function 'dbar1' is required to either be public or '@usableFromInline' because the original function 'bar1' is public or '@usableFromInline'}}
public func bar1(_ x: Float) -> Float { return 1 }
private func dbar1(_ x: Float) -> (Float, (Float) -> Float) { return (1, {$0}) }
public func bar1(_ x: Float) -> Float { x }
@derivative(of: bar1)
private func dbar1(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() }
12 changes: 12 additions & 0 deletions test/AutoDiff/differentiable_attr_parse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@ struct Foo {
var x: Float
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:)) // okay
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:) where T : FloatingPoint) // okay
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: (self, x, y), vjp: foo(_:_:)) // okay
func bar(_ x: Float, _ y: Float) -> Float {
return 1 + x
}

// expected-warning @+1 2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: (self, x, y), jvp: bar, vjp: foo(_:_:)) // okay
func bar(_ x: Float, _ y: Float) -> Float {
return 1 + x
Expand Down Expand Up @@ -55,6 +59,7 @@ func playWellWithOtherAttrs(_ x: Float, _: Float) -> Float {
}

@_transparent
// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: (self), vjp: _vjpSquareRoot) // okay
public func squareRoot() -> Self {
var lhs = self
Expand Down Expand Up @@ -109,6 +114,7 @@ func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:), 3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
Expand Down Expand Up @@ -139,11 +145,13 @@ func two(x: Float, y: Float) -> Float {
return x + y
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:) where T) // expected-error {{expected ':' or '==' to indicate a conformance or same-type requirement}}
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
Expand All @@ -154,11 +162,13 @@ func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:),) // expected-error {{unexpected ',' separator}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(vjp: foo(_:_:), where T) // expected-error {{unexpected ',' separator}}
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
Expand All @@ -174,6 +184,7 @@ func slope5(_ x: Float) -> Float {
return 5 * x
}

// expected-warning @+1 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
@differentiable(wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
func slope5(_ x: Float) -> Float {
return 6 * x
Expand All @@ -185,6 +196,7 @@ func localDifferentiableDeclaration() {
func foo1(_ x: Float) -> Float

// Not okay. Derivative registration can only be non-local.
// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
// expected-error @+1 {{attribute '@differentiable(jvp:vjp:)' can only be used in a non-local scope}}
@differentiable(vjp: dfoo2)
func foo2(_ x: Float) -> Float
Expand Down
Loading