Some iterative optimization algorithms (e.g. gradient descent with weight decay), when applied to a model (of type `Model` conforming to `Differentiable`), compute tangent vectors (of type `Model.TangentVector`) from both a gradient and the model's existing parameters (stored properties of `Model`). However, a model that is differentiable via Euclidean differentiation is not necessarily a vector space with `Model == Model.TangentVector`, because `Model` may contain properties that do not appear in `Model.TangentVector`, such as Boolean flags or other hyperparameters. This use case is extremely common in machine learning, where most differentiation happens in Euclidean space.
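For concreteness, here is the update rule this description has in mind, assuming plain gradient descent with L2 weight decay. The symbols $\theta$ (the differentiable parameters), $\eta$ (learning rate), and $\lambda$ (decay coefficient) are notation chosen here for illustration:

$$
\theta_{t+1} = \theta_t - \eta \, \nabla L(\theta_t) - \eta \lambda \, \theta_t
$$

The last term is the difficulty: it needs $\theta_t$ itself as a value of type `Model.TangentVector`, which `Differentiable` alone does not provide.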
To make this use case possible, we introduce `EuclideanDifferentiable`, which generalizes types that are a product manifold of a differentiable vector space and some arbitrary manifold where the product manifold's tangent space is equal to the differentiable vector space component. In other words, `T: EuclideanDifferentiable` means that `T == V × M` where `V: Differentiable & AdditiveArithmetic` and `T.TangentVector == V == V.TangentVector`. The protocol has a `vectorView` property, which returns the `V` part (the differentiable vector space part) of the value.
`EuclideanDifferentiable` allows us to **generically** express the kind of iterative optimization algorithm that uses a model's parameters when computing a tangent vector.
```swift
let 𝛁L: Model.TangentVector = ...
model.move(along: -η * 𝛁L - η * λ * model.vectorView)
```
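The snippet above can be fleshed out into a self-contained sketch. This is not the protocol as implemented in this PR: `EuclideanDifferentiableSketch`, `ScalableSketch`, `DenseSketch`, and `weightDecayStep` are hypothetical names invented here so the example compiles without any differentiation support, and scalar multiplication is spelled `scaled(by:)` instead of `*`.

```swift
// Hypothetical helper: scalar multiplication for tangent vectors.
protocol ScalableSketch {
    func scaled(by scalar: Float) -> Self
}

// Hypothetical stand-in for the protocol described above (T == V × M).
protocol EuclideanDifferentiableSketch {
    associatedtype TangentVector: AdditiveArithmetic & ScalableSketch
    /// The vector-space component `V` of the value.
    var vectorView: TangentVector { get }
    /// Moves `self` along the given tangent direction.
    mutating func move(along direction: TangentVector)
}

// A toy model: two differentiable parameters plus a Boolean flag
// that does not appear in `TangentVector`.
struct DenseSketch: EuclideanDifferentiableSketch {
    var weight: Float
    var bias: Float
    var useBias: Bool

    struct TangentVector: AdditiveArithmetic, ScalableSketch {
        var weight: Float
        var bias: Float

        static var zero: TangentVector { TangentVector(weight: 0, bias: 0) }

        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight + rhs.weight, bias: lhs.bias + rhs.bias)
        }

        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight - rhs.weight, bias: lhs.bias - rhs.bias)
        }

        func scaled(by scalar: Float) -> TangentVector {
            TangentVector(weight: weight * scalar, bias: bias * scalar)
        }
    }

    var vectorView: TangentVector {
        TangentVector(weight: weight, bias: bias)
    }

    mutating func move(along direction: TangentVector) {
        weight += direction.weight
        bias += direction.bias
    }
}

// Generic over any conforming model: -η * 𝛁L - η * λ * model.vectorView.
func weightDecayStep<Model: EuclideanDifferentiableSketch>(
    _ model: inout Model,
    gradient: Model.TangentVector,
    learningRate η: Float,
    weightDecay λ: Float
) {
    let step = gradient.scaled(by: -η) - model.vectorView.scaled(by: η * λ)
    model.move(along: step)
}

var model = DenseSketch(weight: 2, bias: 4, useBias: true)
let gradient = DenseSketch.TangentVector(weight: 1, bias: 2)
weightDecayStep(&model, gradient: gradient, learningRate: 0.5, weightDecay: 0.5)
```

Note that `weightDecayStep` never touches `useBias`: only the `V` component participates in the update, which is the point of separating `vectorView` from the rest of the model.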
TODO: We need to implement derived conformances for the `vectorView` property.
Thanks to @sgugger, @marcrasi, @mattjj, @dougalm, and @caseychu for their input on solving this use case.
Partially resolves [tensorflow/swift-apis#456](tensorflow/swift-apis#456).