
[AutoDiff] [stdlib] Add 'zeroTangentVector' property to 'Differentiable'. #26828

@rxwei rxwei commented Aug 25, 2019

A zero tangent vector is necessary for optimization on models with an array of parameters, especially for optimizers that iterate over parameters using key paths. The current implementation of some key-path-based optimizers (https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Optimizers/MomentumBased.swift) is wrong in that it won't work with models that contain an array of parameters, because tangent vectors like infinityNorm are initialized as .zero.
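The failure mode can be sketched in plain Swift. This is a minimal illustration, not the swift-apis code: `updatedInfinityNorm`, `shapelessZero`, and `instanceZero` are hypothetical names, and the sketch assumes that `.zero` for an array-shaped tangent vector carries no elements.

```swift
// Hypothetical helper: updates a running infinity norm element by element,
// addressing elements through key paths as a key-path-based optimizer would.
// Returns nil when the state does not have one slot per gradient element,
// because indexing past the end of an array would trap at run time.
func updatedInfinityNorm(state: [Float], gradient: [Float]) -> [Float]? {
    guard state.count == gradient.count else { return nil }
    var state = state
    for i in gradient.indices {
        let element = \[Float].[i]  // key path to the i-th element
        state[keyPath: element] = max(state[keyPath: element],
                                      abs(gradient[keyPath: element]))
    }
    return state
}

let gradient: [Float] = [0.5, -2.0, 1.0]

// Assumption: a shape-agnostic `.zero` has no elements to address.
let shapelessZero: [Float] = []
// A zero built from the instance has one slot per parameter.
let instanceZero = [Float](repeating: 0, count: gradient.count)

let broken = updatedInfinityNorm(state: shapelessZero, gradient: gradient)
let working = updatedInfinityNorm(state: instanceZero, gradient: gradient)
```

Here `broken` is `nil` because the key paths that are valid for the parameters have nothing to address in the empty state, while `working` is `[0.5, 2.0, 1.0]`.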

An earlier version of these optimizers, using the deprecated AllDifferentiableVariables property, would give the correct results, but would be heavyweight and inefficient because it would need to 1. add a constraint TangentVector == AllDifferentiableVariables to optimizers, and 2. make a copy of all parameters and reset them to .zero. Since we are deprecating AllDifferentiableVariables, this is not the right direction.

This problem also means that our Differentiable abstraction needs to provide a general mechanism for obtaining a zero tangent vector for a given instance. Hence we add a zeroTangentVector property to the Differentiable protocol.

Zero tangent vectors do not have a canonical mathematical definition, but they make sense for Differentiable in the standard library because Swift does not have dependent types and thus cannot have a TangentVector that depends on a point on a differentiable manifold. Manopt also has an API, M.zerovec(x), that creates a zero tangent vector at a point (see their API doc at https://www.manopt.org/tutorial.html).

Adding zeroTangentVector will make it possible to deprecate AllDifferentiableVariables completely, because currently some fast.ai notebooks depend on initializing parameter gradients using AllDifferentiableVariables.

The new Differentiable protocol looks like the following. The design overview (http://bit.ly/swift-autodiff) has been updated to reflect this change. This patch also updates doc comments on Differentiable, TangentVector, and move(along:) to match the design overview.

protocol Differentiable {
    /// ...
    associatedtype TangentVector: Differentiable & AdditiveArithmetic
    /// ...
    mutating func move(along direction: TangentVector)

    /// A tangent vector such that `move(along: zeroTangentVector)` will not
    /// modify `self`.
    /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
    ///   but types whose tangent vectors depend on instance properties of `self`
    ///   need to provide a different implementation. For example, the tangent
    ///   vector of an `Array` depends on the array’s `count`.
    var zeroTangentVector: TangentVector { get }
}
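
As a rough sketch of how a conforming type might satisfy this requirement, the following uses a simplified stand-in protocol so that it compiles without the differentiable-programming toolchain. `SketchDifferentiable` and `Weights` are hypothetical names, and the stand-in drops the `Differentiable` constraint on `TangentVector`; it is not the standard library's implementation.

```swift
// Hypothetical, simplified stand-in for the protocol in this PR.
protocol SketchDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
    var zeroTangentVector: TangentVector { get }
}

// Default for types whose tangent vectors do not depend on the instance.
extension SketchDifferentiable {
    var zeroTangentVector: TangentVector { .zero }
}

// A scalar can rely on the default: `TangentVector.zero` is already correct.
extension Double: SketchDifferentiable {
    typealias TangentVector = Double
    mutating func move(along direction: Double) { self += direction }
}

// An array-backed type whose tangent vector shape depends on `values.count`.
struct Weights: SketchDifferentiable {
    var values: [Double]

    struct TangentVector: AdditiveArithmetic {
        var values: [Double]

        // No instance is available here, so this zero carries no elements.
        static var zero: TangentVector { TangentVector(values: []) }

        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            if lhs.values.isEmpty { return rhs }  // empty acts as the identity
            if rhs.values.isEmpty { return lhs }
            precondition(lhs.values.count == rhs.values.count, "count mismatch")
            return TangentVector(values: zip(lhs.values, rhs.values).map { $0 + $1 })
        }

        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            if rhs.values.isEmpty { return lhs }
            if lhs.values.isEmpty { return TangentVector(values: rhs.values.map { -$0 }) }
            precondition(lhs.values.count == rhs.values.count, "count mismatch")
            return TangentVector(values: zip(lhs.values, rhs.values).map { $0 - $1 })
        }
    }

    mutating func move(along direction: TangentVector) {
        if direction.values.isEmpty { return }  // shapeless zero: nothing to do
        precondition(direction.values.count == values.count, "count mismatch")
        for i in values.indices { values[i] += direction.values[i] }
    }

    // Instance-dependent zero: one zero per element of `values`.
    var zeroTangentVector: TangentVector {
        TangentVector(values: Array(repeating: 0, count: values.count))
    }
}

var weights = Weights(values: [1, 2, 3])
let zeroAtWeights = weights.zeroTangentVector  // three zeros, matching `weights`
weights.move(along: zeroAtWeights)             // leaves `weights` unchanged
```

The design point is that `Weights.TangentVector.zero` cannot know how many elements a particular `Weights` value has, whereas `zeroTangentVector` is computed from the instance and therefore has a slot for every parameter.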

TODO: Add derived conformances for zeroTangentVector, and remove @available and the default implementation.

Partially resolves TF-708 (https://bugs.swift.org/browse/TF-708).

@rxwei rxwei added the tensorflow This is for "tensorflow" branch PRs. label Aug 25, 2019
/// but types whose tangent vectors depend on instance properties of `self`
/// need to provide a different implementation. For example, the tangent
/// vector of an `Array` depends on the array’s `count`.
@available(*, deprecated, message: """

This has to be deprecated because protocol requirements cannot be unavailable, and deprecated is the only way to produce a warning.

@rxwei rxwei requested a review from dan-zheng August 25, 2019 07:49
rxwei commented Aug 25, 2019

@swift-ci please test tensorflow

@dan-zheng dan-zheng left a comment

zeroTangentVector stub LGTM.

@rxwei rxwei merged commit daa67ca into swiftlang:tensorflow Aug 25, 2019
@rxwei rxwei deleted the zerotangentvector-without-derivedconformances branch August 25, 2019 08:43