This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[AutoDiff] Use @derivative for derivative registration. #591

Merged 1 commit on Dec 22, 2019

@dan-zheng (Member) commented Dec 21, 2019

Rewrite all derivative registrations that use @differentiable(jvp:vjp:) to use
@derivative instead. This has no functional impact.

Keep original @differentiable attributes so that derivative functions are
publicly exposed.

When retroactive derivative registration is complete:

  • @differentiable(jvp:vjp:) will be deprecated.
  • @derivative attribute will be the canonical way to register derivatives.

Resolves TF-1076.


No remaining uses of @differentiable(vjp:) for derivative registration:

$ grep -nr "vjp:" *
# No remaining uses.

Example:

// Before: `@differentiable(vjp:)`.
extension Tensor where Scalar: Numeric {
    @differentiable(vjp: _vjpAdd where Scalar: TensorFlowFloatingPoint)
    public static func + (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
}

extension Tensor where Scalar: TensorFlowFloatingPoint {
    @inlinable
    static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (
        value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
    ) { ... }
}

// After: `@derivative`.
extension Tensor where Scalar: Numeric {
    @differentiable(where Scalar: TensorFlowFloatingPoint)
    public static func + (lhs: Tensor, rhs: Tensor) -> Tensor { ... }
}

extension Tensor where Scalar: TensorFlowFloatingPoint {
    @inlinable
    @derivative(of: +)
    static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (
        value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
    ) { ... }
}
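
For readers unfamiliar with the `(value:pullback:)` shape that both versions of `_vjpAdd` return, here is a minimal sketch in plain Swift. It uses no differentiation attributes, `Double` stands in for `Tensor`, and the names `add` and `vjpAdd` are hypothetical; it only illustrates what a VJP for addition computes, not the library's actual implementation.

```swift
// Plain Swift sketch: no @differentiable/@derivative attributes are used,
// and Double stands in for Tensor. All names here are hypothetical.

/// The original function (stands in for `Tensor.+`).
func add(_ lhs: Double, _ rhs: Double) -> Double {
    return lhs + rhs
}

/// A VJP-shaped function for `add`: it returns the original value together
/// with a pullback that maps an incoming gradient `v` to gradients with
/// respect to `lhs` and `rhs`. For addition, both gradients are `v`.
func vjpAdd(_ lhs: Double, _ rhs: Double) -> (
    value: Double, pullback: (Double) -> (Double, Double)
) {
    return (value: add(lhs, rhs), pullback: { v in (v, v) })
}

let (sum, sumPullback) = vjpAdd(2, 3)
let (gradLhs, gradRhs) = sumPullback(1)
print(sum, gradLhs, gradRhs)  // prints "5.0 1.0 1.0"
```

With `@derivative(of: +)`, a function of this shape is attached to the original declaration from the derivative's side, rather than being named in the original's `@differentiable(vjp:)` attribute.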

@dan-zheng dan-zheng force-pushed the derivative-attr-everywhere branch from 3f63065 to d1ca725 Compare December 21, 2019 20:39

@saeta (Contributor) left a comment

Wow, really nice work. Just one nit about the removed test, but otherwise LGTM. (Eagerly approving to not block you over the weekend.)

static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = Device.getDefault
) -> (Tensor, (Tensor) -> Scalar) {
@derivative(of: init(_:on:))
static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = Device.getDefault) -> (

@dan-zheng Perhaps it would make sense to rename _vjpXXX functions to _derivativeXXX?

@dan-zheng (Member, Author) commented

I don't feel strongly! The best name is func _ (a private anonymous function), which is used in the differentiable programming manifesto but not yet implemented.

When linear functions and transposition are done, users will register only primitive JVP functions returning a @differentiable(linear) differential function, and primitive VJP functions will be removed. At that point, @derivative functions will unambiguously be JVP functions, so derivativeXXX names become appropriate and vjpXXX names become obsolete.
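
To make the JVP/VJP distinction concrete, here is a rough sketch in plain Swift, again with `Double` standing in for `Tensor` and with hypothetical names (`jvpMultiply`, `vjpMultiply`). It does not use `@differentiable(linear)` or any transposition feature; it only checks by hand that the pullback is the transpose of the differential for one simple function.

```swift
// Plain Swift sketch: no differentiation attributes are used, and Double
// stands in for Tensor. All names here are hypothetical.

/// JVP-shaped function for f(x, y) = x * y: returns the value and a
/// differential, a linear map from input tangents to the output tangent.
func jvpMultiply(_ x: Double, _ y: Double) -> (
    value: Double, differential: (Double, Double) -> Double
) {
    return (value: x * y, differential: { dx, dy in y * dx + x * dy })
}

/// VJP-shaped function for the same f: returns the value and a pullback,
/// a linear map from the output gradient to the input gradients.
func vjpMultiply(_ x: Double, _ y: Double) -> (
    value: Double, pullback: (Double) -> (Double, Double)
) {
    return (value: x * y, pullback: { v in (v * y, v * x) })
}

let jvp = jvpMultiply(3, 4)
let vjp = vjpMultiply(3, 4)

// Transposition check: v * differential(tx, ty) equals the dot product of
// pullback(v) with (tx, ty).
let (tx, ty, v) = (1.0, 2.0, 5.0)
let viaDifferential = v * jvp.differential(tx, ty)
let (gx, gy) = vjp.pullback(v)
let viaPullback = gx * tx + gy * ty
print(viaDifferential == viaPullback)  // prints "true"
```

Because the pullback is determined by the differential in this way, registering only the JVP would be enough once transposition is available, which is the scenario the comment above describes.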
