Skip to content

Commit e3cc71f

Browse files
committed
[AutoDiff] Deprecate Differentiable.AllDifferentiableVariables.
Deprecate the `AllDifferentiableVariables` associated type and `var allDifferentiableVariables` property of the `Differentiable` protocol. `AllDifferentiableVariables` is not essential for differentiable programming and was added as a workaround to enable key-path-based machine learning optimizers: let parameters and gradients have same type (`AllDifferentiableVariables == TangentVector`) to enable joint key-path iteration. It is possible to implement key-path-based machine learning optimizers via other means (do key-path-based operations on `TangentVector`, then call `Differentiable.move(along:)` to perform update), so `AllDifferentiableVariables` is no longer necessary. Resolves TF-707.
1 parent 6773a69 commit e3cc71f

File tree

10 files changed

+152
-579
lines changed

10 files changed

+152
-579
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,7 @@ IDENTIFIER(by)
155155
IDENTIFIER(scale)
156156
IDENTIFIER(x)
157157
// Differentiable
158-
IDENTIFIER(AllDifferentiableVariables)
159158
IDENTIFIER(TangentVector)
160-
IDENTIFIER(allDifferentiableVariables)
161159
IDENTIFIER(move)
162160

163161
// Kinds of layout constraints

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 100 additions & 399 deletions
Large diffs are not rendered by default.

lib/Sema/DerivedConformances.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,6 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
285285
if (name.isSimpleName(ctx.Id_typeList))
286286
return getRequirement(KnownProtocolKind::TensorGroup);
287287

288-
// SWIFT_ENABLE_TENSORFLOW
289-
// Differentiable.allDifferentiableVariables
290-
if (name.isSimpleName(ctx.Id_allDifferentiableVariables))
291-
return getRequirement(KnownProtocolKind::Differentiable);
292-
293288
return nullptr;
294289
}
295290

@@ -454,9 +449,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
454449

455450
// SWIFT_ENABLE_TENSORFLOW
456451
// Differentiable.TangentVector
457-
// Differentiable.AllDifferentiableVariables
458-
if (name.isSimpleName(ctx.Id_TangentVector) ||
459-
name.isSimpleName(ctx.Id_AllDifferentiableVariables))
452+
if (name.isSimpleName(ctx.Id_TangentVector))
460453
return getRequirement(KnownProtocolKind::Differentiable);
461454

462455
// SWIFT_ENABLE_TENSORFLOW

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,12 @@ public struct Tracked<T> {
6565
}
6666
private var handle: Box
6767

68-
@differentiable(
69-
vjp: _vjpInit
70-
where T : Differentiable, T == T.AllDifferentiableVariables,
71-
T == T.TangentVector
72-
)
68+
@differentiable(vjp: _vjpInit where T : Differentiable, T == T.TangentVector)
7369
public init(_ value: T) {
7470
self.handle = Box(value)
7571
}
7672

77-
@differentiable(
78-
vjp: _vjpValue
79-
where T : Differentiable, T == T.AllDifferentiableVariables,
80-
T == T.TangentVector
81-
)
73+
@differentiable(vjp: _vjpValue where T : Differentiable, T == T.TangentVector)
8274
public var value: T {
8375
get { handle.value }
8476
set { handle.value = newValue }
@@ -174,17 +166,11 @@ extension Tracked : Strideable where T : Strideable, T.Stride == T.Stride.Magnit
174166
}
175167

176168
// For now, `T` must be restricted to trivial types (like `Float` or `Tensor`).
177-
extension Tracked : Differentiable
178-
where T : Differentiable, T == T.AllDifferentiableVariables,
179-
T == T.TangentVector
180-
{
181-
public typealias AllDifferentiableVariables = Tracked<T.AllDifferentiableVariables>
169+
extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector {
182170
public typealias TangentVector = Tracked<T.TangentVector>
183171
}
184172

185-
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
186-
T == T.TangentVector
187-
{
173+
extension Tracked where T : Differentiable, T == T.TangentVector {
188174
@usableFromInline
189175
internal static func _vjpInit(_ value: T)
190176
-> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) {
@@ -197,9 +183,7 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
197183
}
198184
}
199185

200-
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
201-
T == T.TangentVector
202-
{
186+
extension Tracked where T : Differentiable, T == T.TangentVector {
203187
@usableFromInline
204188
@differentiating(+)
205189
internal static func _vjpAdd(lhs: Self, rhs: Self)
@@ -216,7 +200,7 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
216200
}
217201

218202
extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
219-
T == T.AllDifferentiableVariables, T == T.TangentVector {
203+
T == T.TangentVector {
220204
@usableFromInline
221205
@differentiating(*)
222206
internal static func _vjpMultiply(lhs: Self, rhs: Self)
@@ -225,8 +209,7 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
225209
}
226210
}
227211

228-
extension Tracked where T : Differentiable & FloatingPoint,
229-
T == T.AllDifferentiableVariables, T == T.TangentVector {
212+
extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector {
230213
@usableFromInline
231214
@differentiating(/)
232215
internal static func _vjpDivide(lhs: Self, rhs: Self)

stdlib/public/core/Array.swift

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,26 +1952,6 @@ extension Array where Element : Differentiable {
19521952

19531953
public typealias TangentVector =
19541954
Array<Element.TangentVector>.DifferentiableView
1955-
public typealias AllDifferentiableVariables =
1956-
Array<Element.AllDifferentiableVariables>.DifferentiableView
1957-
1958-
public var allDifferentiableVariables: AllDifferentiableVariables {
1959-
get {
1960-
return AllDifferentiableVariables(
1961-
base.map { $0.allDifferentiableVariables })
1962-
}
1963-
set {
1964-
precondition(
1965-
base.count == newValue.base.count,
1966-
"cannot set Array.DifferentiableView.AllDifferentiableVariables " +
1967-
"with count \(base.count) to " +
1968-
"Array.DifferentiableView.AllDifferentiableVariables with " +
1969-
"different count \(newValue.base.count)")
1970-
for i in base.indices {
1971-
base[i].allDifferentiableVariables = newValue.base[i]
1972-
}
1973-
}
1974-
}
19751955

19761956
public mutating func move(along direction: TangentVector) {
19771957
precondition(
@@ -2066,27 +2046,13 @@ extension Array.DifferentiableView : AdditiveArithmetic
20662046
/// Makes `Array` differentiable as the product manifold of `Element`
20672047
/// multiplied with itself `count` times.
20682048
extension Array : Differentiable where Element : Differentiable {
2069-
// In an ideal world, `TangentVector`, `TangentVector`, and
2070-
// `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we
2071-
// can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and
2072-
// `TangentVector`, because `Array` already has a static `+` method with
2073-
// different semantics from `AdditiveArithmetic` `+`. So we use
2049+
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
2050+
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
2051+
// `TangentVector` because `Array` already has a static `+` method with
2052+
// different semantics from `AdditiveArithmetic.+`. So we use
20742053
// `Array.DifferentiableView` for all these associated types.
20752054
public typealias TangentVector =
20762055
Array<Element.TangentVector>.DifferentiableView
2077-
public typealias AllDifferentiableVariables =
2078-
Array<Element.AllDifferentiableVariables>.DifferentiableView
2079-
2080-
public var allDifferentiableVariables: AllDifferentiableVariables {
2081-
get {
2082-
return DifferentiableView(self).allDifferentiableVariables
2083-
}
2084-
set {
2085-
var view = DifferentiableView(self)
2086-
view.allDifferentiableVariables = newValue
2087-
self = view.base
2088-
}
2089-
}
20902056

20912057
public mutating func move(along direction: TangentVector) {
20922058
var view = DifferentiableView(self)

stdlib/public/core/AutoDiff.swift

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -153,26 +153,24 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
153153
/// tangent spaces are finite-dimensional.
154154
public protocol Differentiable {
155155
associatedtype TangentVector: Differentiable & AdditiveArithmetic
156-
where TangentVector.TangentVector == TangentVector,
157-
AllDifferentiableVariables.AllDifferentiableVariables ==
158-
AllDifferentiableVariables,
159-
AllDifferentiableVariables.TangentVector == TangentVector
160-
/// The type of all differentiable variables in this type.
161-
associatedtype AllDifferentiableVariables : Differentiable
162-
163-
/// All differentiable variables of this value.
164-
var allDifferentiableVariables: AllDifferentiableVariables { get set }
156+
where TangentVector.TangentVector == TangentVector
165157

166158
/// Moves `self` along the value space towards the given tangent vector. In
167159
/// Riemannian geometry (mathematics), this represents an exponential map.
168160
mutating func move(along direction: TangentVector)
169161

162+
@available(*, deprecated,
163+
message: "'AllDifferentiableVariables' is now equal to 'Self' and will be removed")
164+
typealias AllDifferentiableVariables = Self
165+
170166
@available(*, deprecated,
171167
message: "'CotangentVector' is now equal to 'TangentVector' and will be removed")
172168
typealias CotangentVector = TangentVector
173169
}
174170

175-
public extension Differentiable where AllDifferentiableVariables == Self {
171+
public extension Differentiable {
172+
@available(*, deprecated,
173+
message: "'allDifferentiableVariables' is now equal to 'self' and will be removed")
176174
var allDifferentiableVariables: AllDifferentiableVariables {
177175
get { return self }
178176
set { self = newValue }
@@ -731,8 +729,7 @@ internal protocol _AnyDerivativeBox {
731729

732730
/// Returns the underlying value unboxed to the given type, if possible.
733731
func _unboxed<U>(to type: U.Type) -> U?
734-
where U : Differentiable, U.TangentVector == U,
735-
U.AllDifferentiableVariables == U
732+
where U : Differentiable, U.TangentVector == U
736733
}
737734

738735
extension _AnyDerivativeBox {
@@ -754,8 +751,7 @@ internal func _derivativeTypeMismatch(
754751
}
755752

756753
internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
757-
where T : Differentiable, T.TangentVector == T,
758-
T.AllDifferentiableVariables == T
754+
where T : Differentiable, T.TangentVector == T
759755
{
760756
/// The underlying base value.
761757
var _base: T
@@ -770,8 +766,7 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
770766
}
771767

772768
func _unboxed<U>(to type: U.Type) -> U?
773-
where U : Differentiable, U.TangentVector == U,
774-
U.AllDifferentiableVariables == U
769+
where U : Differentiable, U.TangentVector == U
775770
{
776771
return (self as? _ConcreteDerivativeBox<U>)?._base
777772
}
@@ -861,24 +856,19 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
861856

862857
/// Creates a type-erased derivative from the given derivative.
863858
@differentiable(vjp: _vjpInit(_:))
864-
public init<T>(_ base: T)
865-
where T : Differentiable, T.TangentVector == T,
866-
T.AllDifferentiableVariables == T
867-
{
859+
public init<T>(_ base: T) where T : Differentiable, T.TangentVector == T {
868860
self._box = _ConcreteDerivativeBox<T>(base)
869861
}
870862

871863
@usableFromInline internal static func _vjpInit<T>(
872864
_ base: T
873865
) -> (AnyDerivative, (AnyDerivative) -> T.TangentVector)
874-
where T : Differentiable, T.TangentVector == T,
875-
T.AllDifferentiableVariables == T
866+
where T : Differentiable, T.TangentVector == T
876867
{
877868
return (AnyDerivative(base), { v in v.base as! T.TangentVector })
878869
}
879870

880871
public typealias TangentVector = AnyDerivative
881-
public typealias AllDifferentiableVariables = AnyDerivative
882872

883873
// `Equatable` requirements (implied by `AdditiveArithmetic`).
884874
public static func == (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
@@ -929,10 +919,6 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
929919
}
930920

931921
// `Differentiable` requirements.
932-
public var allDifferentiableVariables: AllDifferentiableVariables {
933-
get { return AnyDerivative(_box: _box._allDifferentiableVariables) }
934-
// set { _box._allDifferentiableVariables = newValue._box }
935-
}
936922
public mutating func move(along direction: TangentVector) {
937923
if _box._isOpaqueZero() {
938924
_box = direction._box

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1906,7 +1906,6 @@ extension ${Self} : VectorProtocol {
19061906

19071907
extension ${Self} : Differentiable {
19081908
public typealias TangentVector = ${Self}
1909-
public typealias AllDifferentiableVariables = ${Self}
19101909

19111910
public mutating func move(along direction: TangentVector) {
19121911
self += direction

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,6 @@ extension SIMD${n} : Differentiable
196196
where Scalar : Differentiable & BinaryFloatingPoint,
197197
Scalar.TangentVector : BinaryFloatingPoint {
198198
public typealias TangentVector = SIMD${n}
199-
public typealias AllDifferentiableVariables = SIMD${n}
200-
public func tangentVector(from cotangent: TangentVector) -> TangentVector {
201-
return cotangent
202-
}
203199
}
204200

205201
extension SIMD${n}

0 commit comments

Comments
 (0)