Commit daa67ca

[AutoDiff] [stdlib] Add 'zeroTangentVector' property to 'Differentiable'. (#26828)
A zero tangent vector is necessary for optimizing models that contain an array of parameters, especially for optimizers that iterate over parameters using key paths. The [current implementation](https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Optimizers/MomentumBased.swift) of some key-path-based optimizers is wrong in that it won't work with models that contain an array of parameters (tangent vectors like `infinityNorm` are initialized as `.zero`). An earlier version of these optimizers, which used the deprecated `AllDifferentiableVariables` property, would give correct results but would be heavyweight and inefficient, because it would need to:

1. add a constraint `TangentVector == AllDifferentiableVariables` to optimizers, and
2. make a copy of all parameters and reset them to `.zero`.

Since we are deprecating `AllDifferentiableVariables`, this is not the right direction.

This problem also means that our `Differentiable` abstraction needs to provide a general mechanism for obtaining a zero tangent vector at a certain instance. Hence we add a `zeroTangentVector` property to the `Differentiable` protocol.

Zero tangent vectors do not have a canonical mathematical definition, but they make sense for `Differentiable` in the standard library because Swift does not have dependent types and thus cannot have a `TangentVector` that depends on a point on a differentiable manifold. Manopt also has an API, `M.zerovec(x)`, that creates a zero tangent vector at a point (see their API doc [here](https://www.manopt.org/tutorial.html)).

Adding `zeroTangentVector` will make it possible to deprecate `AllDifferentiableVariables` completely, because currently some fast.ai notebooks depend on initializing parameter gradients using `AllDifferentiableVariables`.

The new `Differentiable` protocol looks like the following. The [design overview](http://bit.ly/swift-autodiff) has been updated to reflect this change.

```swift
protocol Differentiable {
  /// ...
  associatedtype TangentVector: Differentiable & AdditiveArithmetic

  /// ...
  mutating func move(along direction: TangentVector)

  /// A tangent vector such that `move(along: zeroTangentVector)` will not
  /// modify `self`.
  /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
  ///   but types whose tangent vectors depend on instance properties of `self`
  ///   need to provide a different implementation. For example, the tangent
  ///   vector of an `Array` depends on the array’s `count`.
  var zeroTangentVector: TangentVector { get }
}
```

TODO: Add derived conformances for `zeroTangentVector`, and remove `@available` and the default implementation.

Partially resolves [TF-708](https://bugs.swift.org/browse/TF-708).
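
The difference between a static `.zero` and an instance-dependent zero can be sketched in plain Swift, without the differentiable-programming toolchain. The names `ToyDifferentiable`, `ArrayTangent`, and `ToyModel` below are hypothetical stand-ins invented for this illustration; they are not the standard library types, and the real `Array` tangent vector may behave differently.

```swift
// Hypothetical stand-in for the `Differentiable` protocol described above,
// written against plain Swift so that it runs without autodiff support.
protocol ToyDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
    var zeroTangentVector: TangentVector { get }
}

// Hypothetical tangent vector for an array of parameters. Its static `zero`
// cannot know how many parameters a particular model has, so it is empty.
struct ArrayTangent: AdditiveArithmetic {
    var elements: [Double]

    static var zero: ArrayTangent { ArrayTangent(elements: []) }

    // An empty tangent acts as the additive identity; otherwise counts must match.
    static func + (lhs: ArrayTangent, rhs: ArrayTangent) -> ArrayTangent {
        if lhs.elements.isEmpty { return rhs }
        if rhs.elements.isEmpty { return lhs }
        precondition(lhs.elements.count == rhs.elements.count, "count mismatch")
        return ArrayTangent(elements: zip(lhs.elements, rhs.elements).map { $0 + $1 })
    }

    static func - (lhs: ArrayTangent, rhs: ArrayTangent) -> ArrayTangent {
        if rhs.elements.isEmpty { return lhs }
        if lhs.elements.isEmpty { return ArrayTangent(elements: rhs.elements.map { -$0 }) }
        precondition(lhs.elements.count == rhs.elements.count, "count mismatch")
        return ArrayTangent(elements: zip(lhs.elements, rhs.elements).map { $0 - $1 })
    }
}

struct ToyModel: ToyDifferentiable {
    typealias TangentVector = ArrayTangent
    var weights: [Double]

    mutating func move(along direction: ArrayTangent) {
        // An empty direction is treated as "no movement".
        if direction.elements.isEmpty { return }
        precondition(direction.elements.count == weights.count, "count mismatch")
        for i in weights.indices { weights[i] += direction.elements[i] }
    }

    // Instance-dependent zero: one zero entry per parameter of this model.
    var zeroTangentVector: ArrayTangent {
        ArrayTangent(elements: Array(repeating: 0, count: weights.count))
    }
}

let model = ToyModel(weights: [1, 2, 3])
let staticZero = ToyModel.TangentVector.zero
let shapedZero = model.zeroTangentVector

print(staticZero.elements.count)  // 0
print(shapedZero.elements.count)  // 3

// Moving along the instance-dependent zero leaves the model unchanged.
var moved = model
moved.move(along: shapedZero)
print(moved.weights == model.weights)  // true
```

In this sketch, an optimizer that keeps per-parameter state and indexes into it element by element would find nothing to index in `staticZero`, whereas `shapedZero` has one slot per weight. That is the kind of mismatch the commit message describes for key-path-based optimizers.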
1 parent 5e12ce4 commit daa67ca

1 file changed, +27 -2 lines changed

stdlib/public/core/AutoDiff.swift

Lines changed: 27 additions & 2 deletions
@@ -152,13 +152,30 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
 /// A type that mathematically represents a differentiable manifold whose
 /// tangent spaces are finite-dimensional.
 public protocol Differentiable {
+  /// A type representing a differentiable value’s derivatives.
+  ///
+  /// Mathematically, this is equivalent to the tangent bundle of the
+  /// differentiable manifold represented by the differentiable type.
   associatedtype TangentVector: Differentiable & AdditiveArithmetic
     where TangentVector.TangentVector == TangentVector
 
-  /// Moves `self` along the value space towards the given tangent vector. In
-  /// Riemannian geometry (mathematics), this represents an exponential map.
+  /// Moves `self` along the given direction. In Riemannian geometry, this is
+  /// equivalent to exponential map, which moves `self` on the geodesic surface
+  /// along the given tangent vector.
   mutating func move(along direction: TangentVector)
 
+  /// A tangent vector such that `move(along: zeroTangentVector)` will not
+  /// modify `self`.
+  /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
+  ///   but types whose tangent vectors depend on instance properties of `self`
+  ///   need to provide a different implementation. For example, the tangent
+  ///   vector of an `Array` depends on the array’s `count`.
+  @available(*, deprecated, message: """
+      `zeroTangentVector` derivation has not been implemented; do not use \
+      this property
+      """)
+  var zeroTangentVector: TangentVector { get }
+
   @available(*, deprecated,
              message: "'AllDifferentiableVariables' is now equal to 'Self' and will be removed")
   typealias AllDifferentiableVariables = Self
@@ -175,6 +192,14 @@ public extension Differentiable {
     get { return self }
     set { self = newValue }
   }
+
+  // This is a temporary solution that allows us to add `zeroTangentVector`
+  // without implementing derived conformances. This property is marked
+  // unavailable because it will produce incorrect results when tangent vectors
+  // depend on instance properties of `self`.
+  // FIXME: Implement derived conformance and remove this default
+  // implementation.
+  var zeroTangentVector: TangentVector { .zero }
 }
 
 public extension Differentiable where TangentVector == Self {
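
The pattern used here, a protocol requirement paired with a default implementation in an extension, can be illustrated with a small plain-Swift sketch. The names `SketchDifferentiable`, `ListTangent`, `Scalar`, `Vec`, and `zeroTangent(of:)` are hypothetical and exist only for this example. Because the property is a requirement, a conforming type's own implementation is used even from generic code, while types that provide nothing fall back to `.zero`.

```swift
// Hypothetical protocol mirroring only the part of `Differentiable` relevant here.
protocol SketchDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    var zeroTangentVector: TangentVector { get }
}

extension SketchDifferentiable {
    // Fallback analogous to the temporary default in the diff above.
    var zeroTangentVector: TangentVector { .zero }
}

// Hypothetical tangent type whose static `zero` carries no element count.
struct ListTangent: AdditiveArithmetic {
    var elements: [Double]

    static var zero: ListTangent { ListTangent(elements: []) }

    static func + (lhs: ListTangent, rhs: ListTangent) -> ListTangent {
        if lhs.elements.isEmpty { return rhs }
        if rhs.elements.isEmpty { return lhs }
        return ListTangent(elements: zip(lhs.elements, rhs.elements).map { $0 + $1 })
    }

    static func - (lhs: ListTangent, rhs: ListTangent) -> ListTangent {
        if rhs.elements.isEmpty { return lhs }
        if lhs.elements.isEmpty { return ListTangent(elements: rhs.elements.map { -$0 }) }
        return ListTangent(elements: zip(lhs.elements, rhs.elements).map { $0 - $1 })
    }
}

// Relies on the default: the static zero is already correct for a scalar.
struct Scalar: SketchDifferentiable {
    typealias TangentVector = Double
    var value: Double
}

// Provides its own implementation because the zero depends on `count`.
struct Vec: SketchDifferentiable {
    typealias TangentVector = ListTangent
    var elements: [Double]

    var zeroTangentVector: ListTangent {
        ListTangent(elements: Array(repeating: 0, count: elements.count))
    }
}

// Generic code sees whichever implementation the concrete type supplies.
func zeroTangent<T: SketchDifferentiable>(of value: T) -> T.TangentVector {
    value.zeroTangentVector
}

let scalarZero = zeroTangent(of: Scalar(value: 2))
let vecZero = zeroTangent(of: Vec(elements: [1, 2, 3]))

print(scalarZero)              // 0.0
print(vecZero.elements.count)  // 3
```

If `Vec` omitted its own property, the fallback would return the empty `ListTangent.zero`, the kind of incorrect result that the comment in the diff warns about for types whose tangent vectors depend on instance properties.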
