You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[AutoDiff] [stdlib] Add 'zeroTangentVector' property to 'Differentiable'. (#26828)
Zero tangent vector is necessary for optimizations on models with an array of parameters, especially for optimizers that iterates over parameters using key paths. The [current implementation](https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Optimizers/MomentumBased.swift) of some key-path-based optimizers is wrong in that it won't work with models that contain an array of parameters (tangent vectors like `infinityNorm` are initialized as `.zero`).
An earlier version of these optimizer using the deprecated `AllDifferentiableVariables` property would give the correct results, but would be heavyweight and inefficient because they'd need to 1. add a constraint `TangentVector == AllDifferentiableVariables` to optimizers, and 2. make a copy of all parameters and resetting them to `.zero`. Since we are deprecating `AllDifferentiableVariables`, this is not the right direction.
This problem also means that our `Differentiable` abstraction needs to provide a general mechanism of obtaining a zero tangent vector at a certain instance. Hence we add a `zeroTangentVector` property to the `Differentiable` protocol.
Zero tangent vectors do not have a canonical mathematical definition, but makes sense for `Differentiable` in the standard library because Swift does not have dependent types and thus cannot have a `TangentVector` that depends on a point on a differentiable manifold. Manopt also has an API, `M.zerovec(x)`, that creates a zero tangent vector at a point (see their API doc [here](https://www.manopt.org/tutorial.html)).
Adding `zeroTangentVector` will make it possible to deprecate `AllDifferentiableVariables` completely, because currently some fast.ai notebooks depend on initializing parameter gradients using `AllDifferentiableVariables`.
The new `Differentiable` protocol looks like the following. The [design overview](http://bit.ly/swift-autodiff) has been updated to reflect this change.
```swift
protocol Differentiable {
/// ...
associatedtype TangentVector: Differentiable & AdditiveArithmetic
/// ...
mutating func move(along direction: TangentVector)
/// A tangent vector such that `move(along: zeroTangentVector)` will not
/// modify `self`.
/// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
/// but types whose tangent vectors depend on instance properties of `self`
/// need to provide a different implementation. For example, the tangent
/// vector of an `Array` depends on the array’s `count`.
var zeroTangentVector: TangentVector { get }
}
```
TODO: Add derived conformances for `zeroTangentVector`, and remove `@available` and the default implementation.
Partially resolves [TF-708](https://bugs.swift.org/browse/TF-708).
0 commit comments