Skip to content

Commit 2e582b9

Browse files
authored
[AutoDiff] Deprecate Differentiable.AllDifferentiableVariables. (#26527)
Deprecate the `AllDifferentiableVariables` associated type and `var allDifferentiableVariables` property of the `Differentiable` protocol. `AllDifferentiableVariables` is not essential for differentiable programming and was added as a workaround to enable key-path-based machine learning optimizers: let parameters and gradients have the same type (`AllDifferentiableVariables == TangentVector`) to enable joint key-path iteration. It is possible to implement key-path-based machine learning optimizers via other means (do key-path-based operations on `TangentVector`, then call `Differentiable.move(along:)` to perform update), so `AllDifferentiableVariables` is no longer necessary. Resolves TF-707.
1 parent 6773a69 commit 2e582b9

18 files changed

+175
-618
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,7 @@ IDENTIFIER(by)
155155
IDENTIFIER(scale)
156156
IDENTIFIER(x)
157157
// Differentiable
158-
IDENTIFIER(AllDifferentiableVariables)
159158
IDENTIFIER(TangentVector)
160-
IDENTIFIER(allDifferentiableVariables)
161159
IDENTIFIER(move)
162160

163161
// Kinds of layout constraints

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 103 additions & 398 deletions
Large diffs are not rendered by default.

lib/Sema/DerivedConformances.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,6 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
285285
if (name.isSimpleName(ctx.Id_typeList))
286286
return getRequirement(KnownProtocolKind::TensorGroup);
287287

288-
// SWIFT_ENABLE_TENSORFLOW
289-
// Differentiable.allDifferentiableVariables
290-
if (name.isSimpleName(ctx.Id_allDifferentiableVariables))
291-
return getRequirement(KnownProtocolKind::Differentiable);
292-
293288
return nullptr;
294289
}
295290

@@ -454,9 +449,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
454449

455450
// SWIFT_ENABLE_TENSORFLOW
456451
// Differentiable.TangentVector
457-
// Differentiable.AllDifferentiableVariables
458-
if (name.isSimpleName(ctx.Id_TangentVector) ||
459-
name.isSimpleName(ctx.Id_AllDifferentiableVariables))
452+
if (name.isSimpleName(ctx.Id_TangentVector))
460453
return getRequirement(KnownProtocolKind::Differentiable);
461454

462455
// SWIFT_ENABLE_TENSORFLOW

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,12 @@ public struct Tracked<T> {
6565
}
6666
private var handle: Box
6767

68-
@differentiable(
69-
vjp: _vjpInit
70-
where T : Differentiable, T == T.AllDifferentiableVariables,
71-
T == T.TangentVector
72-
)
68+
@differentiable(vjp: _vjpInit where T : Differentiable, T == T.TangentVector)
7369
public init(_ value: T) {
7470
self.handle = Box(value)
7571
}
7672

77-
@differentiable(
78-
vjp: _vjpValue
79-
where T : Differentiable, T == T.AllDifferentiableVariables,
80-
T == T.TangentVector
81-
)
73+
@differentiable(vjp: _vjpValue where T : Differentiable, T == T.TangentVector)
8274
public var value: T {
8375
get { handle.value }
8476
set { handle.value = newValue }
@@ -174,17 +166,11 @@ extension Tracked : Strideable where T : Strideable, T.Stride == T.Stride.Magnit
174166
}
175167

176168
// For now, `T` must be restricted to trivial types (like `Float` or `Tensor`).
177-
extension Tracked : Differentiable
178-
where T : Differentiable, T == T.AllDifferentiableVariables,
179-
T == T.TangentVector
180-
{
181-
public typealias AllDifferentiableVariables = Tracked<T.AllDifferentiableVariables>
169+
extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector {
182170
public typealias TangentVector = Tracked<T.TangentVector>
183171
}
184172

185-
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
186-
T == T.TangentVector
187-
{
173+
extension Tracked where T : Differentiable, T == T.TangentVector {
188174
@usableFromInline
189175
internal static func _vjpInit(_ value: T)
190176
-> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) {
@@ -197,9 +183,7 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
197183
}
198184
}
199185

200-
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
201-
T == T.TangentVector
202-
{
186+
extension Tracked where T : Differentiable, T == T.TangentVector {
203187
@usableFromInline
204188
@differentiating(+)
205189
internal static func _vjpAdd(lhs: Self, rhs: Self)
@@ -216,7 +200,7 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
216200
}
217201

218202
extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
219-
T == T.AllDifferentiableVariables, T == T.TangentVector {
203+
T == T.TangentVector {
220204
@usableFromInline
221205
@differentiating(*)
222206
internal static func _vjpMultiply(lhs: Self, rhs: Self)
@@ -225,8 +209,7 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
225209
}
226210
}
227211

228-
extension Tracked where T : Differentiable & FloatingPoint,
229-
T == T.AllDifferentiableVariables, T == T.TangentVector {
212+
extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector {
230213
@usableFromInline
231214
@differentiating(/)
232215
internal static func _vjpDivide(lhs: Self, rhs: Self)

stdlib/public/core/Array.swift

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,26 +1952,6 @@ extension Array where Element : Differentiable {
19521952

19531953
public typealias TangentVector =
19541954
Array<Element.TangentVector>.DifferentiableView
1955-
public typealias AllDifferentiableVariables =
1956-
Array<Element.AllDifferentiableVariables>.DifferentiableView
1957-
1958-
public var allDifferentiableVariables: AllDifferentiableVariables {
1959-
get {
1960-
return AllDifferentiableVariables(
1961-
base.map { $0.allDifferentiableVariables })
1962-
}
1963-
set {
1964-
precondition(
1965-
base.count == newValue.base.count,
1966-
"cannot set Array.DifferentiableView.AllDifferentiableVariables " +
1967-
"with count \(base.count) to " +
1968-
"Array.DifferentiableView.AllDifferentiableVariables with " +
1969-
"different count \(newValue.base.count)")
1970-
for i in base.indices {
1971-
base[i].allDifferentiableVariables = newValue.base[i]
1972-
}
1973-
}
1974-
}
19751955

19761956
public mutating func move(along direction: TangentVector) {
19771957
precondition(
@@ -2066,27 +2046,13 @@ extension Array.DifferentiableView : AdditiveArithmetic
20662046
/// Makes `Array` differentiable as the product manifold of `Element`
20672047
/// multiplied with itself `count` times.
20682048
extension Array : Differentiable where Element : Differentiable {
2069-
// In an ideal world, `TangentVector`, `TangentVector`, and
2070-
// `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we
2071-
// can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and
2072-
// `TangentVector`, because `Array` already has a static `+` method with
2073-
// different semantics from `AdditiveArithmetic` `+`. So we use
2049+
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
2050+
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
2051+
// `TangentVector` because `Array` already has a static `+` method with
2052+
// different semantics from `AdditiveArithmetic.+`. So we use
20742053
// `Array.DifferentiableView` for all these associated types.
20752054
public typealias TangentVector =
20762055
Array<Element.TangentVector>.DifferentiableView
2077-
public typealias AllDifferentiableVariables =
2078-
Array<Element.AllDifferentiableVariables>.DifferentiableView
2079-
2080-
public var allDifferentiableVariables: AllDifferentiableVariables {
2081-
get {
2082-
return DifferentiableView(self).allDifferentiableVariables
2083-
}
2084-
set {
2085-
var view = DifferentiableView(self)
2086-
view.allDifferentiableVariables = newValue
2087-
self = view.base
2088-
}
2089-
}
20902056

20912057
public mutating func move(along direction: TangentVector) {
20922058
var view = DifferentiableView(self)

stdlib/public/core/AutoDiff.swift

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -153,26 +153,24 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
153153
/// tangent spaces are finite-dimensional.
154154
public protocol Differentiable {
155155
associatedtype TangentVector: Differentiable & AdditiveArithmetic
156-
where TangentVector.TangentVector == TangentVector,
157-
AllDifferentiableVariables.AllDifferentiableVariables ==
158-
AllDifferentiableVariables,
159-
AllDifferentiableVariables.TangentVector == TangentVector
160-
/// The type of all differentiable variables in this type.
161-
associatedtype AllDifferentiableVariables : Differentiable
162-
163-
/// All differentiable variables of this value.
164-
var allDifferentiableVariables: AllDifferentiableVariables { get set }
156+
where TangentVector.TangentVector == TangentVector
165157

166158
/// Moves `self` along the value space towards the given tangent vector. In
167159
/// Riemannian geometry (mathematics), this represents an exponential map.
168160
mutating func move(along direction: TangentVector)
169161

162+
@available(*, deprecated,
163+
message: "'AllDifferentiableVariables' is now equal to 'Self' and will be removed")
164+
typealias AllDifferentiableVariables = Self
165+
170166
@available(*, deprecated,
171167
message: "'CotangentVector' is now equal to 'TangentVector' and will be removed")
172168
typealias CotangentVector = TangentVector
173169
}
174170

175-
public extension Differentiable where AllDifferentiableVariables == Self {
171+
public extension Differentiable {
172+
@available(*, deprecated,
173+
message: "'allDifferentiableVariables' is now equal to 'self' and will be removed")
176174
var allDifferentiableVariables: AllDifferentiableVariables {
177175
get { return self }
178176
set { self = newValue }
@@ -723,16 +721,14 @@ internal protocol _AnyDerivativeBox {
723721
func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox
724722

725723
// `Differentiable` requirements.
726-
var _allDifferentiableVariables: _AnyDerivativeBox { get }
727724
mutating func _move(along direction: _AnyDerivativeBox)
728725

729726
/// The underlying base value, type-erased to `Any`.
730727
var _typeErasedBase: Any { get }
731728

732729
/// Returns the underlying value unboxed to the given type, if possible.
733730
func _unboxed<U>(to type: U.Type) -> U?
734-
where U : Differentiable, U.TangentVector == U,
735-
U.AllDifferentiableVariables == U
731+
where U : Differentiable, U.TangentVector == U
736732
}
737733

738734
extension _AnyDerivativeBox {
@@ -754,8 +750,7 @@ internal func _derivativeTypeMismatch(
754750
}
755751

756752
internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
757-
where T : Differentiable, T.TangentVector == T,
758-
T.AllDifferentiableVariables == T
753+
where T : Differentiable, T.TangentVector == T
759754
{
760755
/// The underlying base value.
761756
var _base: T
@@ -770,8 +765,7 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
770765
}
771766

772767
func _unboxed<U>(to type: U.Type) -> U?
773-
where U : Differentiable, U.TangentVector == U,
774-
U.AllDifferentiableVariables == U
768+
where U : Differentiable, U.TangentVector == U
775769
{
776770
return (self as? _ConcreteDerivativeBox<U>)?._base
777771
}
@@ -824,10 +818,6 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
824818

825819
// `Differentiable` requirements.
826820

827-
var _allDifferentiableVariables: _AnyDerivativeBox {
828-
return _ConcreteDerivativeBox(_base.allDifferentiableVariables)
829-
}
830-
831821
mutating func _move(along direction: _AnyDerivativeBox) {
832822
if direction._isOpaqueZero() {
833823
return
@@ -861,24 +851,19 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
861851

862852
/// Creates a type-erased derivative from the given derivative.
863853
@differentiable(vjp: _vjpInit(_:))
864-
public init<T>(_ base: T)
865-
where T : Differentiable, T.TangentVector == T,
866-
T.AllDifferentiableVariables == T
867-
{
854+
public init<T>(_ base: T) where T : Differentiable, T.TangentVector == T {
868855
self._box = _ConcreteDerivativeBox<T>(base)
869856
}
870857

871858
@usableFromInline internal static func _vjpInit<T>(
872859
_ base: T
873860
) -> (AnyDerivative, (AnyDerivative) -> T.TangentVector)
874-
where T : Differentiable, T.TangentVector == T,
875-
T.AllDifferentiableVariables == T
861+
where T : Differentiable, T.TangentVector == T
876862
{
877863
return (AnyDerivative(base), { v in v.base as! T.TangentVector })
878864
}
879865

880866
public typealias TangentVector = AnyDerivative
881-
public typealias AllDifferentiableVariables = AnyDerivative
882867

883868
// `Equatable` requirements (implied by `AdditiveArithmetic`).
884869
public static func == (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
@@ -929,10 +914,6 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
929914
}
930915

931916
// `Differentiable` requirements.
932-
public var allDifferentiableVariables: AllDifferentiableVariables {
933-
get { return AnyDerivative(_box: _box._allDifferentiableVariables) }
934-
// set { _box._allDifferentiableVariables = newValue._box }
935-
}
936917
public mutating func move(along direction: TangentVector) {
937918
if _box._isOpaqueZero() {
938919
_box = direction._box

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1906,7 +1906,6 @@ extension ${Self} : VectorProtocol {
19061906

19071907
extension ${Self} : Differentiable {
19081908
public typealias TangentVector = ${Self}
1909-
public typealias AllDifferentiableVariables = ${Self}
19101909

19111910
public mutating func move(along direction: TangentVector) {
19121911
self += direction

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,6 @@ extension SIMD${n} : Differentiable
196196
where Scalar : Differentiable & BinaryFloatingPoint,
197197
Scalar.TangentVector : BinaryFloatingPoint {
198198
public typealias TangentVector = SIMD${n}
199-
public typealias AllDifferentiableVariables = SIMD${n}
200-
public func tangentVector(from cotangent: TangentVector) -> TangentVector {
201-
return cotangent
202-
}
203199
}
204200

205201
extension SIMD${n}

test/AutoDiff/control_flow.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,9 @@ ControlFlowTests.test("Enums") {
492492
return input * w1
493493
}
494494
}
495-
expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20),
495+
expectEqual((Dense.TangentVector(w1: 10), 20),
496496
Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) }))
497-
expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4),
497+
expectEqual((Dense.TangentVector(w1: 2), 4),
498498
Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) }))
499499

500500
indirect enum Indirect {

test/AutoDiff/derived_conformances.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ DerivedConformanceTests.test("MemberwiseInitializers") {
2626
@noDerivative let constant2 = Double(1)
2727
var x = Float(1)
2828
}
29-
expectEqual(HasNoDerivativeConstant.AllDifferentiableVariables(x: 0),
30-
HasNoDerivativeConstant.AllDifferentiableVariables.zero)
3129
expectEqual(HasNoDerivativeConstant.TangentVector(x: 0),
3230
HasNoDerivativeConstant.TangentVector.zero)
3331
}

test/AutoDiff/derived_differentiable.swift

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ public struct Foo : Differentiable {
1212
// CHECK-AST: @differentiable
1313
// CHECK-AST: public var a: Float
1414
// CHECK-AST: internal init(a: Float)
15-
// CHECK-AST: public struct AllDifferentiableVariables
16-
// CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables
17-
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
18-
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
15+
// CHECK-AST: public struct TangentVector
16+
// CHECK-AST: public typealias TangentVector = Foo.TangentVector
1917

2018
// CHECK-SILGEN-LABEL: // Foo.a.getter
2119
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float
@@ -33,7 +31,6 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
3331
// CHECK-AST: internal var dummy: PointwiseMultiplicativeDummy
3432
// CHECK-AST: internal init(a: Float, dummy: PointwiseMultiplicativeDummy)
3533
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
36-
// CHECK-AST: internal typealias AllDifferentiableVariables = AdditiveTangentIsSelf
3734

3835
struct TestNoDerivative : Differentiable {
3936
var w: Float
@@ -44,10 +41,8 @@ struct TestNoDerivative : Differentiable {
4441
// CHECK-AST: var w: Float
4542
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
4643
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
47-
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
48-
// CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables
49-
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
50-
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
44+
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
45+
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
5146

5247
struct TestPointwiseMultiplicative : Differentiable {
5348
var w: PointwiseMultiplicativeDummy
@@ -58,10 +53,8 @@ struct TestPointwiseMultiplicative : Differentiable {
5853
// CHECK-AST: var w: PointwiseMultiplicativeDummy
5954
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
6055
// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
61-
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
62-
// CHECK-AST: internal typealias AllDifferentiableVariables = TestPointwiseMultiplicative.AllDifferentiableVariables
63-
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.AllDifferentiableVariables
64-
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.AllDifferentiableVariables
56+
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
57+
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector
6558

6659

6760
struct TestKeyPathIterable : Differentiable, KeyPathIterable {
@@ -73,10 +66,8 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
7366
// CHECK-AST: var w: Float
7467
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
7568
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
76-
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, ElementaryFunctions, VectorProtocol
77-
// CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables
78-
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
79-
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
69+
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
70+
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector
8071

8172
struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic {
8273
var x: T.TangentVector
@@ -86,7 +77,6 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
8677
// CHECK-AST: internal var x: T.TangentVector
8778
// CHECK-AST: internal init(x: T.TangentVector)
8879
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
89-
// CHECK-AST: internal typealias AllDifferentiableVariables = GenericTanMember<T>
9080
// CHECK-AST: internal static var zero: GenericTanMember<T> { get }
9181
// CHECK-AST: internal static func + (lhs: GenericTanMember<T>, rhs: GenericTanMember<T>) -> GenericTanMember<T>
9282
// CHECK-AST: internal static func - (lhs: GenericTanMember<T>, rhs: GenericTanMember<T>) -> GenericTanMember<T>

0 commit comments

Comments
 (0)