Skip to content

Commit fae7287

Browse files
committed
Revert "[AutoDiff] Deprecate Differentiable.AllDifferentiableVariables. (swiftlang#26527)"
This reverts commit 2e582b9.
1 parent 5dd9c50 commit fae7287

17 files changed

+617
-174
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ IDENTIFIER(by)
155155
IDENTIFIER(scale)
156156
IDENTIFIER(x)
157157
// Differentiable
158+
IDENTIFIER(AllDifferentiableVariables)
158159
IDENTIFIER(TangentVector)
160+
IDENTIFIER(allDifferentiableVariables)
159161
IDENTIFIER(move)
160162

161163
// Kinds of layout constraints

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 398 additions & 103 deletions
Large diffs are not rendered by default.

lib/Sema/DerivedConformances.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
285285
if (name.isSimpleName(ctx.Id_typeList))
286286
return getRequirement(KnownProtocolKind::TensorGroup);
287287

288+
// SWIFT_ENABLE_TENSORFLOW
289+
// Differentiable.allDifferentiableVariables
290+
if (name.isSimpleName(ctx.Id_allDifferentiableVariables))
291+
return getRequirement(KnownProtocolKind::Differentiable);
292+
288293
return nullptr;
289294
}
290295

@@ -449,7 +454,9 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
449454

450455
// SWIFT_ENABLE_TENSORFLOW
451456
// Differentiable.TangentVector
452-
if (name.isSimpleName(ctx.Id_TangentVector))
457+
// Differentiable.AllDifferentiableVariables
458+
if (name.isSimpleName(ctx.Id_TangentVector) ||
459+
name.isSimpleName(ctx.Id_AllDifferentiableVariables))
453460
return getRequirement(KnownProtocolKind::Differentiable);
454461

455462
// SWIFT_ENABLE_TENSORFLOW

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,20 @@ public struct Tracked<T> {
6565
}
6666
private var handle: Box
6767

68-
@differentiable(vjp: _vjpInit where T : Differentiable, T == T.TangentVector)
68+
@differentiable(
69+
vjp: _vjpInit
70+
where T : Differentiable, T == T.AllDifferentiableVariables,
71+
T == T.TangentVector
72+
)
6973
public init(_ value: T) {
7074
self.handle = Box(value)
7175
}
7276

73-
@differentiable(vjp: _vjpValue where T : Differentiable, T == T.TangentVector)
77+
@differentiable(
78+
vjp: _vjpValue
79+
where T : Differentiable, T == T.AllDifferentiableVariables,
80+
T == T.TangentVector
81+
)
7482
public var value: T {
7583
get { handle.value }
7684
set { handle.value = newValue }
@@ -166,11 +174,17 @@ extension Tracked : Strideable where T : Strideable, T.Stride == T.Stride.Magnit
166174
}
167175

168176
// For now, `T` must be restricted to trivial types (like `Float` or `Tensor`).
169-
extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector {
177+
extension Tracked : Differentiable
178+
where T : Differentiable, T == T.AllDifferentiableVariables,
179+
T == T.TangentVector
180+
{
181+
public typealias AllDifferentiableVariables = Tracked<T.AllDifferentiableVariables>
170182
public typealias TangentVector = Tracked<T.TangentVector>
171183
}
172184

173-
extension Tracked where T : Differentiable, T == T.TangentVector {
185+
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
186+
T == T.TangentVector
187+
{
174188
@usableFromInline
175189
internal static func _vjpInit(_ value: T)
176190
-> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) {
@@ -183,7 +197,9 @@ extension Tracked where T : Differentiable, T == T.TangentVector {
183197
}
184198
}
185199

186-
extension Tracked where T : Differentiable, T == T.TangentVector {
200+
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
201+
T == T.TangentVector
202+
{
187203
@usableFromInline
188204
@differentiating(+)
189205
internal static func _vjpAdd(lhs: Self, rhs: Self)
@@ -200,7 +216,7 @@ extension Tracked where T : Differentiable, T == T.TangentVector {
200216
}
201217

202218
extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
203-
T == T.TangentVector {
219+
T == T.AllDifferentiableVariables, T == T.TangentVector {
204220
@usableFromInline
205221
@differentiating(*)
206222
internal static func _vjpMultiply(lhs: Self, rhs: Self)
@@ -209,7 +225,8 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
209225
}
210226
}
211227

212-
extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector {
228+
extension Tracked where T : Differentiable & FloatingPoint,
229+
T == T.AllDifferentiableVariables, T == T.TangentVector {
213230
@usableFromInline
214231
@differentiating(/)
215232
internal static func _vjpDivide(lhs: Self, rhs: Self)

stdlib/public/core/Array.swift

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,26 @@ extension Array where Element : Differentiable {
19521952

19531953
public typealias TangentVector =
19541954
Array<Element.TangentVector>.DifferentiableView
1955+
public typealias AllDifferentiableVariables =
1956+
Array<Element.AllDifferentiableVariables>.DifferentiableView
1957+
1958+
public var allDifferentiableVariables: AllDifferentiableVariables {
1959+
get {
1960+
return AllDifferentiableVariables(
1961+
base.map { $0.allDifferentiableVariables })
1962+
}
1963+
set {
1964+
precondition(
1965+
base.count == newValue.base.count,
1966+
"cannot set Array.DifferentiableView.AllDifferentiableVariables " +
1967+
"with count \(base.count) to " +
1968+
"Array.DifferentiableView.AllDifferentiableVariables with " +
1969+
"different count \(newValue.base.count)")
1970+
for i in base.indices {
1971+
base[i].allDifferentiableVariables = newValue.base[i]
1972+
}
1973+
}
1974+
}
19551975

19561976
public mutating func move(along direction: TangentVector) {
19571977
precondition(
@@ -2046,13 +2066,27 @@ extension Array.DifferentiableView : AdditiveArithmetic
20462066
/// Makes `Array` differentiable as the product manifold of `Element`
20472067
/// multiplied with itself `count` times.
20482068
extension Array : Differentiable where Element : Differentiable {
2049-
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
2050-
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
2051-
// `TangentVector` because `Array` already has a static `+` method with
2052-
// different semantics from `AdditiveArithmetic.+`. So we use
2069+
// In an ideal world, `TangentVector`, `TangentVector`, and
2070+
// `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we
2071+
// can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and
2072+
// `TangentVector`, because `Array` already has a static `+` method with
2073+
// different semantics from `AdditiveArithmetic` `+`. So we use
20532074
// `Array.DifferentiableView` for all these associated types.
20542075
public typealias TangentVector =
20552076
Array<Element.TangentVector>.DifferentiableView
2077+
public typealias AllDifferentiableVariables =
2078+
Array<Element.AllDifferentiableVariables>.DifferentiableView
2079+
2080+
public var allDifferentiableVariables: AllDifferentiableVariables {
2081+
get {
2082+
return DifferentiableView(self).allDifferentiableVariables
2083+
}
2084+
set {
2085+
var view = DifferentiableView(self)
2086+
view.allDifferentiableVariables = newValue
2087+
self = view.base
2088+
}
2089+
}
20562090

20572091
public mutating func move(along direction: TangentVector) {
20582092
var view = DifferentiableView(self)

stdlib/public/core/AutoDiff.swift

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,24 +153,26 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
153153
/// tangent spaces are finite-dimensional.
154154
public protocol Differentiable {
155155
associatedtype TangentVector: Differentiable & AdditiveArithmetic
156-
where TangentVector.TangentVector == TangentVector
156+
where TangentVector.TangentVector == TangentVector,
157+
AllDifferentiableVariables.AllDifferentiableVariables ==
158+
AllDifferentiableVariables,
159+
AllDifferentiableVariables.TangentVector == TangentVector
160+
/// The type of all differentiable variables in this type.
161+
associatedtype AllDifferentiableVariables : Differentiable
162+
163+
/// All differentiable variables of this value.
164+
var allDifferentiableVariables: AllDifferentiableVariables { get set }
157165

158166
/// Moves `self` along the value space towards the given tangent vector. In
159167
/// Riemannian geometry (mathematics), this represents an exponential map.
160168
mutating func move(along direction: TangentVector)
161169

162-
@available(*, deprecated,
163-
message: "'AllDifferentiableVariables' is now equal to 'Self' and will be removed")
164-
typealias AllDifferentiableVariables = Self
165-
166170
@available(*, deprecated,
167171
message: "'CotangentVector' is now equal to 'TangentVector' and will be removed")
168172
typealias CotangentVector = TangentVector
169173
}
170174

171-
public extension Differentiable {
172-
@available(*, deprecated,
173-
message: "'allDifferentiableVariables' is now equal to 'self' and will be removed")
175+
public extension Differentiable where AllDifferentiableVariables == Self {
174176
var allDifferentiableVariables: AllDifferentiableVariables {
175177
get { return self }
176178
set { self = newValue }
@@ -721,14 +723,16 @@ internal protocol _AnyDerivativeBox {
721723
func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox
722724

723725
// `Differentiable` requirements.
726+
var _allDifferentiableVariables: _AnyDerivativeBox { get }
724727
mutating func _move(along direction: _AnyDerivativeBox)
725728

726729
/// The underlying base value, type-erased to `Any`.
727730
var _typeErasedBase: Any { get }
728731

729732
/// Returns the underlying value unboxed to the given type, if possible.
730733
func _unboxed<U>(to type: U.Type) -> U?
731-
where U : Differentiable, U.TangentVector == U
734+
where U : Differentiable, U.TangentVector == U,
735+
U.AllDifferentiableVariables == U
732736
}
733737

734738
extension _AnyDerivativeBox {
@@ -750,7 +754,8 @@ internal func _derivativeTypeMismatch(
750754
}
751755

752756
internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
753-
where T : Differentiable, T.TangentVector == T
757+
where T : Differentiable, T.TangentVector == T,
758+
T.AllDifferentiableVariables == T
754759
{
755760
/// The underlying base value.
756761
var _base: T
@@ -765,7 +770,8 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
765770
}
766771

767772
func _unboxed<U>(to type: U.Type) -> U?
768-
where U : Differentiable, U.TangentVector == U
773+
where U : Differentiable, U.TangentVector == U,
774+
U.AllDifferentiableVariables == U
769775
{
770776
return (self as? _ConcreteDerivativeBox<U>)?._base
771777
}
@@ -818,6 +824,10 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
818824

819825
// `Differentiable` requirements.
820826

827+
var _allDifferentiableVariables: _AnyDerivativeBox {
828+
return _ConcreteDerivativeBox(_base.allDifferentiableVariables)
829+
}
830+
821831
mutating func _move(along direction: _AnyDerivativeBox) {
822832
if direction._isOpaqueZero() {
823833
return
@@ -851,19 +861,24 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
851861

852862
/// Creates a type-erased derivative from the given derivative.
853863
@differentiable(vjp: _vjpInit(_:))
854-
public init<T>(_ base: T) where T : Differentiable, T.TangentVector == T {
864+
public init<T>(_ base: T)
865+
where T : Differentiable, T.TangentVector == T,
866+
T.AllDifferentiableVariables == T
867+
{
855868
self._box = _ConcreteDerivativeBox<T>(base)
856869
}
857870

858871
@usableFromInline internal static func _vjpInit<T>(
859872
_ base: T
860873
) -> (AnyDerivative, (AnyDerivative) -> T.TangentVector)
861-
where T : Differentiable, T.TangentVector == T
874+
where T : Differentiable, T.TangentVector == T,
875+
T.AllDifferentiableVariables == T
862876
{
863877
return (AnyDerivative(base), { v in v.base as! T.TangentVector })
864878
}
865879

866880
public typealias TangentVector = AnyDerivative
881+
public typealias AllDifferentiableVariables = AnyDerivative
867882

868883
// `Equatable` requirements (implied by `AdditiveArithmetic`).
869884
public static func == (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
@@ -914,6 +929,10 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
914929
}
915930

916931
// `Differentiable` requirements.
932+
public var allDifferentiableVariables: AllDifferentiableVariables {
933+
get { return AnyDerivative(_box: _box._allDifferentiableVariables) }
934+
// set { _box._allDifferentiableVariables = newValue._box }
935+
}
917936
public mutating func move(along direction: TangentVector) {
918937
if _box._isOpaqueZero() {
919938
_box = direction._box

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,7 @@ extension ${Self} : VectorProtocol {
19061906

19071907
extension ${Self} : Differentiable {
19081908
public typealias TangentVector = ${Self}
1909+
public typealias AllDifferentiableVariables = ${Self}
19091910

19101911
public mutating func move(along direction: TangentVector) {
19111912
self += direction

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ extension SIMD${n} : Differentiable
196196
where Scalar : Differentiable & BinaryFloatingPoint,
197197
Scalar.TangentVector : BinaryFloatingPoint {
198198
public typealias TangentVector = SIMD${n}
199+
public typealias AllDifferentiableVariables = SIMD${n}
200+
public func tangentVector(from cotangent: TangentVector) -> TangentVector {
201+
return cotangent
202+
}
199203
}
200204

201205
extension SIMD${n}

test/AutoDiff/control_flow.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,9 @@ ControlFlowTests.test("Enums") {
492492
return input * w1
493493
}
494494
}
495-
expectEqual((Dense.TangentVector(w1: 10), 20),
495+
expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20),
496496
Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) }))
497-
expectEqual((Dense.TangentVector(w1: 2), 4),
497+
expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4),
498498
Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) }))
499499

500500
indirect enum Indirect {

test/AutoDiff/derived_conformances.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ DerivedConformanceTests.test("MemberwiseInitializers") {
2626
@noDerivative let constant2 = Double(1)
2727
var x = Float(1)
2828
}
29+
expectEqual(HasNoDerivativeConstant.AllDifferentiableVariables(x: 0),
30+
HasNoDerivativeConstant.AllDifferentiableVariables.zero)
2931
expectEqual(HasNoDerivativeConstant.TangentVector(x: 0),
3032
HasNoDerivativeConstant.TangentVector.zero)
3133
}

test/AutoDiff/derived_differentiable.swift

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ public struct Foo : Differentiable {
1212
// CHECK-AST: @differentiable
1313
// CHECK-AST: public var a: Float
1414
// CHECK-AST: internal init(a: Float)
15-
// CHECK-AST: public struct TangentVector
16-
// CHECK-AST: public typealias TangentVector = Foo.TangentVector
15+
// CHECK-AST: public struct AllDifferentiableVariables
16+
// CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables
17+
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
18+
// CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables
1719

1820
// CHECK-SILGEN-LABEL: // Foo.a.getter
1921
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float
@@ -31,6 +33,7 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
3133
// CHECK-AST: internal var dummy: PointwiseMultiplicativeDummy
3234
// CHECK-AST: internal init(a: Float, dummy: PointwiseMultiplicativeDummy)
3335
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
36+
// CHECK-AST: internal typealias AllDifferentiableVariables = AdditiveTangentIsSelf
3437

3538
struct TestNoDerivative : Differentiable {
3639
var w: Float
@@ -41,8 +44,10 @@ struct TestNoDerivative : Differentiable {
4144
// CHECK-AST: var w: Float
4245
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
4346
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
44-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
45-
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
47+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
48+
// CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables
49+
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
50+
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables
4651

4752
struct TestPointwiseMultiplicative : Differentiable {
4853
var w: PointwiseMultiplicativeDummy
@@ -53,8 +58,10 @@ struct TestPointwiseMultiplicative : Differentiable {
5358
// CHECK-AST: var w: PointwiseMultiplicativeDummy
5459
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
5560
// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
56-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
57-
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector
61+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
62+
// CHECK-AST: internal typealias AllDifferentiableVariables = TestPointwiseMultiplicative.AllDifferentiableVariables
63+
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.AllDifferentiableVariables
64+
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.AllDifferentiableVariables
5865

5966

6067
struct TestKeyPathIterable : Differentiable, KeyPathIterable {
@@ -66,8 +73,10 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
6673
// CHECK-AST: var w: Float
6774
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
6875
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
69-
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
70-
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector
76+
// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, ElementaryFunctions, VectorProtocol
77+
// CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables
78+
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
79+
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables
7180

7281
struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic {
7382
var x: T.TangentVector
@@ -77,6 +86,7 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
7786
// CHECK-AST: internal var x: T.TangentVector
7887
// CHECK-AST: internal init(x: T.TangentVector)
7988
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
89+
// CHECK-AST: internal typealias AllDifferentiableVariables = GenericTanMember<T>
8090
// CHECK-AST: internal static var zero: GenericTanMember<T> { get }
8191
// CHECK-AST: internal static func + (lhs: GenericTanMember<T>, rhs: GenericTanMember<T>) -> GenericTanMember<T>
8292
// CHECK-AST: internal static func - (lhs: GenericTanMember<T>, rhs: GenericTanMember<T>) -> GenericTanMember<T>

0 commit comments

Comments
 (0)