[AutoDiff] [stdlib] Add 'zeroTangentVector' property to 'Differentiable'. #26828
A zero tangent vector is necessary for optimizations on models with an array of parameters, especially for optimizers that iterate over parameters using key paths. The current implementation of some key-path-based optimizers is wrong in that it won't work with models that contain an array of parameters (tangent vectors like `infinityNorm` are initialized as `.zero`).

An earlier version of these optimizers using the deprecated `AllDifferentiableVariables` property would give the correct results, but would be heavyweight and inefficient because they'd need to 1. add a constraint `TangentVector == AllDifferentiableVariables` to optimizers, and 2. make a copy of all parameters and reset them to `.zero`. Since we are deprecating `AllDifferentiableVariables`, this is not the right direction.

This problem also means that our `Differentiable` abstraction needs to provide a general mechanism for obtaining a zero tangent vector at a certain instance. Hence we add a `zeroTangentVector` property to the `Differentiable` protocol.

Zero tangent vectors do not have a canonical mathematical definition, but they make sense for `Differentiable` in the standard library because Swift does not have dependent types and thus cannot have a `TangentVector` that depends on a point on a differentiable manifold. Manopt also has an API, `M.zerovec(x)`, that creates a zero tangent vector at a point (see their API documentation).

Adding `zeroTangentVector` will make it possible to deprecate `AllDifferentiableVariables` completely, because currently some fast.ai notebooks depend on initializing parameter gradients using `AllDifferentiableVariables`.

The new `Differentiable` protocol gains a `zeroTangentVector` requirement, and the design overview has been updated to reflect this change. This patch also updates doc comments on `Differentiable`, `TangentVector`, and `move(along:)` to match the design overview.

TODO: Add derived conformances for `zeroTangentVector`, and remove `@available` and the default implementation.

Partially resolves TF-708.
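
For readers who want to see the array problem concretely, here is a minimal sketch in plain Swift. It does not use the real `Differentiable` protocol or any code from this patch: `ToyDifferentiable`, `ArrayTangent`, `ParameterList`, and `updatedInfinityNorm` are hypothetical names made up for illustration, and the update rule is only a stand-in for the kind of per-parameter state an optimizer might keep. The point it tries to show is that a static `.zero` cannot know how many parameters an instance holds, while an instance property can.

```swift
// Hypothetical stand-in for `Differentiable`; NOT the protocol from this patch.
protocol ToyDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    /// Moves `self` along the given direction.
    mutating func move(along direction: TangentVector)
    /// A zero tangent vector shaped like this particular instance.
    var zeroTangentVector: TangentVector { get }
}

/// Hypothetical tangent vector for an array of scalar parameters.
struct ArrayTangent: AdditiveArithmetic {
    var values: [Double]

    // A static zero has no instance to consult, so it can only be empty.
    static var zero: ArrayTangent { ArrayTangent(values: []) }

    static func + (lhs: ArrayTangent, rhs: ArrayTangent) -> ArrayTangent {
        if lhs.values.isEmpty { return rhs }
        if rhs.values.isEmpty { return lhs }
        precondition(lhs.values.count == rhs.values.count, "count mismatch")
        return ArrayTangent(values: zip(lhs.values, rhs.values).map { $0 + $1 })
    }

    static func - (lhs: ArrayTangent, rhs: ArrayTangent) -> ArrayTangent {
        if rhs.values.isEmpty { return lhs }
        if lhs.values.isEmpty { return ArrayTangent(values: rhs.values.map { -$0 }) }
        precondition(lhs.values.count == rhs.values.count, "count mismatch")
        return ArrayTangent(values: zip(lhs.values, rhs.values).map { $0 - $1 })
    }
}

/// Hypothetical model fragment holding an array of parameters.
struct ParameterList: ToyDifferentiable {
    var values: [Double]

    mutating func move(along direction: ArrayTangent) {
        // An empty direction is the static zero, so there is nothing to do.
        if direction.values.isEmpty { return }
        precondition(direction.values.count == values.count, "count mismatch")
        for index in values.indices {
            values[index] += direction.values[index]
        }
    }

    // The instance knows its own parameter count, so the zero has the right shape.
    var zeroTangentVector: ArrayTangent {
        ArrayTangent(values: Array(repeating: 0, count: values.count))
    }
}

/// Illustrative elementwise state update; requires the state to be
/// shaped like the gradient, which the static `.zero` is not.
func updatedInfinityNorm(
    _ norm: ArrayTangent, gradient: ArrayTangent, decay: Double
) -> ArrayTangent {
    precondition(norm.values.count == gradient.values.count, "count mismatch")
    return ArrayTangent(
        values: zip(norm.values, gradient.values).map { max(decay * $0, abs($1)) })
}

let parameters = ParameterList(values: [1, 2, 3])
let gradient = ArrayTangent(values: [0.5, -2, 0])
// Starting from `ArrayTangent.zero` here would trip the precondition.
let norm = updatedInfinityNorm(parameters.zeroTangentVector, gradient: gradient, decay: 0.9)
print(norm.values)  // [0.5, 2.0, 0.0]
```

In this sketch the state is initialized from `parameters.zeroTangentVector` rather than from `ArrayTangent.zero`, so it has one entry per parameter without copying the parameters and resetting them.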