Commit 5e12ce4

[AutoDiff] [stdlib] Add 'EuclideanDifferentiable' protocol. (#26827)
Some iterative mathematical optimization algorithms (e.g. gradient descent with weight decay), when applied to a model (of type `Model` that conforms to `Differentiable`), rely on computing tangent vectors (of type `Model.TangentVector`) from a gradient and an existing parameter (a stored property of `Model`). However, models that are differentiable via Euclidean differentiation are not necessarily vector spaces in which `Model` equals `Model.TangentVector`, because `Model` may contain properties that do not appear in `Model.TangentVector`, such as Boolean flags or other hyperparameters. This use case is extremely common in machine learning, where most differentiation happens in Euclidean space.

To make this use case possible, we introduce `EuclideanDifferentiable`, which generalizes types that are a product manifold of a differentiable vector space and some arbitrary manifold, where the product manifold's tangent space is equal to the differentiable vector space component. In other words, `T: EuclideanDifferentiable` means that `T == V × M` where `V: Differentiable & AdditiveArithmetic` and `T.TangentVector == V == V.TangentVector`. The protocol has a `vectorView` property, which returns the `V` part (the differentiable vector space part) of the value.

`EuclideanDifferentiable` allows us to **generically** express the kind of iterative optimization algorithm that uses a model's parameters when computing a tangent vector.

```swift
let 𝛁L: Model.TangentVector = ...
model.move(along: -η * 𝛁L - η * λ * model.vectorView)
```

TODO: We need to implement derived conformances for the `vectorView` property.

Thanks to @sgugger, @marcrasi, @mattjj, @dougalm, and @caseychu for their input on solving this use case.

Partially resolves [tensorflow/swift-apis#456](tensorflow/swift-apis#456).
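The weight-decay update described in the commit message can be sketched generically as follows. This is a minimal illustration under stated assumptions, not the stdlib implementation: `SketchDifferentiable`, `SketchEuclideanDifferentiable`, `FloatScalable`, `weightDecayStep`, and the hand-written `Perceptron.TangentVector` are all hypothetical stand-ins, introduced so the example compiles without the experimental differentiation support.

```swift
// Stand-in protocols (assumptions): they mirror the shape of `Differentiable`
// and `EuclideanDifferentiable` but are NOT the stdlib protocols.
protocol SketchDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
}

protocol SketchEuclideanDifferentiable: SketchDifferentiable {
    /// The differentiable vector component of `self`.
    var vectorView: TangentVector { get set }
}

// Hypothetical helper protocol: scalar multiplication by a `Float`.
protocol FloatScalable {
    static func * (lhs: Float, rhs: Self) -> Self
}

// A model with two differentiable parameters and one flag without a derivative.
struct Perceptron: SketchEuclideanDifferentiable {
    var weight: Float
    var bias: Float
    var useBias: Bool  // does not appear in `TangentVector`

    struct TangentVector: AdditiveArithmetic, FloatScalable {
        var weight: Float
        var bias: Float
        static var zero: TangentVector { TangentVector(weight: 0, bias: 0) }
        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight + rhs.weight, bias: lhs.bias + rhs.bias)
        }
        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight - rhs.weight, bias: lhs.bias - rhs.bias)
        }
        static func * (lhs: Float, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs * rhs.weight, bias: lhs * rhs.bias)
        }
    }

    // Written by hand here; the commit leaves derived conformances as a TODO.
    var vectorView: TangentVector {
        get { TangentVector(weight: weight, bias: bias) }
        set { weight = newValue.weight; bias = newValue.bias }
    }

    mutating func move(along direction: TangentVector) {
        weight += direction.weight
        bias += direction.bias
    }
}

// Generic weight-decay step: move along -η * 𝛁L - η * λ * vectorView.
func weightDecayStep<Model: SketchEuclideanDifferentiable>(
    _ model: inout Model,
    gradient: Model.TangentVector,
    learningRate: Float,
    decay: Float
) where Model.TangentVector: FloatScalable {
    let gradientTerm: Model.TangentVector = learningRate * gradient
    let decayTerm: Model.TangentVector = (learningRate * decay) * model.vectorView
    let step: Model.TangentVector = Model.TangentVector.zero - gradientTerm - decayTerm
    model.move(along: step)
}
```

The function never mentions `Perceptron`, which is the point of the protocol: any conforming model exposes its parameters as a tangent vector through `vectorView`, while non-differentiable properties such as `useBias` stay untouched.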
1 parent cf52777 commit 5e12ce4

1 file changed: +35 -0 lines changed


stdlib/public/core/AutoDiff.swift

Lines changed: 35 additions & 0 deletions
@@ -183,6 +183,41 @@ public extension Differentiable where TangentVector == Self {
   }
 }
 
+/// A type that consists of a differentiable vector space and some other
+/// non-differentiable component.
+///
+/// Mathematically, this represents a product manifold that consists of
+/// a differentiable vector space and some arbitrary manifold, where the tangent
+/// bundle of the entire product manifold is equal to the vector space
+/// component.
+///
+/// This abstraction is useful for representing common differentiable data
+/// structures that contain both differentiable vector properties and other
+/// stored properties that do not have a derivative, e.g.
+///
+/// ```swift
+/// struct Perceptron: @memberwise EuclideanDifferentiable {
+///     var weight: SIMD16<Float>
+///     var bias: Float
+///     @noDerivative var useBias: Bool
+/// }
+/// ```
+///
+/// - Note: Conform a type to `EuclideanDifferentiable` if it is differentiable
+///   only with respect to its vector space component and when its
+///   `TangentVector` is equal to its vector space component.
+public protocol EuclideanDifferentiable: Differentiable {
+  /// The differentiable vector component of `self`.
+  var vectorView: TangentVector { get set }
+}
+
+public extension EuclideanDifferentiable where TangentVector == Self {
+  var vectorView: TangentVector {
+    _read { yield self }
+    _modify { yield &self }
+  }
+}
+
 /// Returns `x` like an identity function. When used in a context where `x` is
 /// being differentiated with respect to, this function will not produce any
 /// derivative at `x`.
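
The constrained extension in this diff gives `vectorView` for free to any type that is its own tangent vector. A rough sketch of that behavior follows, using ordinary `get`/`set` accessors in place of the underscored `_read`/`_modify` accessors. The names `SketchDifferentiable`, `SketchEuclideanDifferentiable`, and `Vector2` are hypothetical stand-ins, not part of the commit.

```swift
// Stand-in protocols (assumptions) mirroring the shape of the stdlib ones.
protocol SketchDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
}

protocol SketchEuclideanDifferentiable: SketchDifferentiable {
    var vectorView: TangentVector { get set }
}

// Default implementation: when a type is its own tangent vector,
// its vector view is simply the value itself.
extension SketchEuclideanDifferentiable where TangentVector == Self {
    var vectorView: TangentVector {
        get { self }
        set { self = newValue }
    }
}

// A plain vector type: no `vectorView` is written here, the default is used.
struct Vector2: AdditiveArithmetic, SketchEuclideanDifferentiable {
    typealias TangentVector = Vector2
    var x: Float
    var y: Float

    static var zero: Vector2 { Vector2(x: 0, y: 0) }
    static func + (lhs: Vector2, rhs: Vector2) -> Vector2 {
        Vector2(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
    }
    static func - (lhs: Vector2, rhs: Vector2) -> Vector2 {
        Vector2(x: lhs.x - rhs.x, y: lhs.y - rhs.y)
    }

    mutating func move(along direction: Vector2) {
        self += direction  // `+=` comes from AdditiveArithmetic's default
    }
}
```

Reading `vectorView` on a `Vector2` returns the value itself, and assigning to it replaces the whole value, so generic optimizer code can treat vector-space types and mixed types such as `Perceptron` uniformly.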

0 commit comments
