Skip to content

Commit 3ef8360

Browse files
authored
[AutoDiff] Add AnyDifferentiable. (#30928)
Add `AnyDifferentiable`, a type-erased wrapper for `Differentiable`-conforming types.
1 parent 0366e61 commit 3ef8360

File tree

3 files changed

+187
-38
lines changed

3 files changed

+187
-38
lines changed

stdlib/public/Differentiation/AnyDerivative.swift renamed to stdlib/public/Differentiation/AnyDifferentiable.swift

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,105 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// This file defines `AnyDerivative`, a type-erased wrapper for
14-
// `Differentiable.TangentVector` associated type implementations.
13+
// This file defines type-erased wrappers for `Differentiable`-conforming types
14+
// and `Differentiable.TangentVector` associated type implementations.
1515
//
1616
//===----------------------------------------------------------------------===//
1717

1818
import Swift
1919

20+
//===----------------------------------------------------------------------===//
21+
// `AnyDifferentiable`
22+
//===----------------------------------------------------------------------===//
23+
24+
internal protocol _AnyDifferentiableBox {
25+
// `Differentiable` requirements.
26+
mutating func _move(along direction: AnyDerivative)
27+
28+
/// The underlying base value, type-erased to `Any`.
29+
var _typeErasedBase: Any { get }
30+
31+
/// Returns the underlying value unboxed to the given type, if possible.
32+
func _unboxed<U: Differentiable>(to type: U.Type) -> U?
33+
}
34+
35+
internal struct _ConcreteDifferentiableBox<T: Differentiable>: _AnyDifferentiableBox
36+
{
37+
/// The underlying base value.
38+
var _base: T
39+
40+
init(_ base: T) {
41+
self._base = base
42+
}
43+
44+
/// The underlying base value, type-erased to `Any`.
45+
var _typeErasedBase: Any {
46+
return _base
47+
}
48+
49+
func _unboxed<U: Differentiable>(to type: U.Type) -> U? {
50+
return (self as? _ConcreteDifferentiableBox<U>)?._base
51+
}
52+
53+
mutating func _move(along direction: AnyDerivative) {
54+
guard
55+
let directionBase =
56+
direction.base as? T.TangentVector
57+
else {
58+
_derivativeTypeMismatch(T.self, type(of: direction.base))
59+
}
60+
_base.move(along: directionBase)
61+
}
62+
}
63+
64+
public struct AnyDifferentiable: Differentiable {
65+
internal var _box: _AnyDifferentiableBox
66+
67+
internal init(_box: _AnyDifferentiableBox) {
68+
self._box = _box
69+
}
70+
71+
/// The underlying base value.
72+
public var base: Any {
73+
return _box._typeErasedBase
74+
}
75+
76+
/// Creates a type-erased derivative from the given derivative.
77+
@differentiable
78+
public init<T: Differentiable>(_ base: T) {
79+
self._box = _ConcreteDifferentiableBox<T>(base)
80+
}
81+
82+
@inlinable
83+
@derivative(of: init)
84+
internal static func _vjpInit<T: Differentiable>(
85+
_ base: T
86+
) -> (value: AnyDifferentiable, pullback: (AnyDerivative) -> T.TangentVector)
87+
{
88+
return (AnyDifferentiable(base), { v in v.base as! T.TangentVector })
89+
}
90+
91+
@inlinable
92+
@derivative(of: init)
93+
internal static func _jvpInit<T: Differentiable>(
94+
_ base: T
95+
) -> (
96+
value: AnyDifferentiable, differential: (T.TangentVector) -> AnyDerivative
97+
) {
98+
return (AnyDifferentiable(base), { dbase in AnyDerivative(dbase) })
99+
}
100+
101+
public typealias TangentVector = AnyDerivative
102+
103+
public mutating func move(along direction: TangentVector) {
104+
_box._move(along: direction)
105+
}
106+
}
107+
108+
//===----------------------------------------------------------------------===//
109+
// `AnyDerivative`
110+
//===----------------------------------------------------------------------===//
111+
20112
@usableFromInline
21113
internal protocol _AnyDerivativeBox {
22114
// `Equatable` requirements (implied by `AdditiveArithmetic`).
@@ -47,18 +139,6 @@ extension _AnyDerivativeBox {
47139
}
48140
}
49141

50-
@inline(never)
51-
@usableFromInline
52-
internal func _derivativeTypeMismatch(
53-
_ x: Any.Type, _ y: Any.Type, file: StaticString = #file, line: UInt = #line
54-
) -> Never {
55-
preconditionFailure(
56-
"""
57-
Derivative type mismatch: \
58-
\(String(reflecting: x)) and \(String(reflecting: y))
59-
""", file: file, line: line)
60-
}
61-
62142
@frozen
63143
@usableFromInline
64144
internal struct _ConcreteDerivativeBox<T>: _AnyDerivativeBox
@@ -290,3 +370,19 @@ public struct AnyDerivative: Differentiable & AdditiveArithmetic {
290370
_box._move(along: direction._box)
291371
}
292372
}
373+
374+
//===----------------------------------------------------------------------===//
375+
// Helpers
376+
//===----------------------------------------------------------------------===//
377+
378+
@inline(never)
379+
@usableFromInline
380+
internal func _derivativeTypeMismatch(
381+
_ x: Any.Type, _ y: Any.Type, file: StaticString = #file, line: UInt = #line
382+
) -> Never {
383+
preconditionFailure(
384+
"""
385+
Derivative type mismatch: \
386+
\(String(reflecting: x)) and \(String(reflecting: y))
387+
""", file: file, line: line)
388+
}

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
1414
Differentiable.swift
1515
DifferentialOperators.swift
1616
DifferentiationUtilities.swift
17-
AnyDerivative.swift
17+
AnyDifferentiable.swift
1818
ArrayDifferentiation.swift
1919

2020
GYB_SOURCES

test/AutoDiff/stdlib/anyderivative.swift renamed to test/AutoDiff/stdlib/anydifferentiable.swift

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import _Differentiation
55
import StdlibUnittest
66

7-
var AnyDerivativeTests = TestSuite("AnyDerivative")
7+
var TypeErasureTests = TestSuite("DifferentiableTypeErasure")
88

9-
struct Vector: Differentiable {
9+
struct Vector: Differentiable, Equatable {
1010
var x, y: Float
1111
}
12-
struct Generic<T: Differentiable>: Differentiable {
12+
struct Generic<T: Differentiable & Equatable>: Differentiable, Equatable {
1313
var x: T
1414
}
1515

@@ -22,28 +22,44 @@ extension AnyDerivative {
2222
}
2323
}
2424

25-
AnyDerivativeTests.test("Vector") {
26-
var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
27-
tan += tan
28-
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
29-
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan)
30-
expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan)
31-
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan))
32-
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
25+
TypeErasureTests.test("AnyDifferentiable operations") {
26+
do {
27+
var any = AnyDifferentiable(Vector(x: 1, y: 1))
28+
let tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
29+
any.move(along: tan)
30+
expectEqual(Vector(x: 2, y: 2), any.base as? Vector)
31+
}
32+
33+
do {
34+
var any = AnyDifferentiable(Generic<Float>(x: 1))
35+
let tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
36+
any.move(along: tan)
37+
expectEqual(Generic<Float>(x: 2), any.base as? Generic<Float>)
38+
}
3339
}
3440

35-
AnyDerivativeTests.test("Generic") {
36-
var tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
37-
let cotan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
38-
tan += tan
39-
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
40-
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan + tan)
41-
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 0)), tan - tan)
42-
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan.moved(along: tan))
43-
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 1)), cotan)
41+
TypeErasureTests.test("AnyDerivative operations") {
42+
do {
43+
var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
44+
tan += tan
45+
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
46+
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan)
47+
expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan)
48+
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan))
49+
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
50+
}
51+
52+
do {
53+
var tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
54+
tan += tan
55+
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
56+
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan + tan)
57+
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 0)), tan - tan)
58+
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan.moved(along: tan))
59+
}
4460
}
4561

46-
AnyDerivativeTests.test("Zero") {
62+
TypeErasureTests.test("AnyDerivative.zero") {
4763
var zero = AnyDerivative.zero
4864
zero += zero
4965
zero -= zero
@@ -66,7 +82,17 @@ AnyDerivativeTests.test("Zero") {
6682
expectEqual(tan, tan)
6783
}
6884

69-
AnyDerivativeTests.test("Casting") {
85+
TypeErasureTests.test("AnyDifferentiable casting") {
86+
let any = AnyDifferentiable(Vector(x: 1, y: 1))
87+
expectEqual(Vector(x: 1, y: 1), any.base as? Vector)
88+
89+
let genericAny = AnyDifferentiable(Generic<Float>(x: 1))
90+
expectEqual(Generic<Float>(x: 1),
91+
genericAny.base as? Generic<Float>)
92+
expectEqual(nil, genericAny.base as? Generic<Double>)
93+
}
94+
95+
TypeErasureTests.test("AnyDerivative casting") {
7096
let tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
7197
expectEqual(Vector.TangentVector(x: 1, y: 1), tan.base as? Vector.TangentVector)
7298

@@ -81,7 +107,34 @@ AnyDerivativeTests.test("Casting") {
81107
expectEqual(nil, zero.base as? Generic<Float>.TangentVector)
82108
}
83109

84-
AnyDerivativeTests.test("Derivatives") {
110+
TypeErasureTests.test("AnyDifferentiable differentiation") {
111+
// Test `AnyDifferentiable` initializer.
112+
do {
113+
let x: Float = 3
114+
let v = AnyDerivative(Float(2))
115+
let 𝛁x = pullback(at: x, in: { AnyDifferentiable($0) })(v)
116+
let expectedVJP: Float = 2
117+
expectEqual(expectedVJP, 𝛁x)
118+
}
119+
120+
do {
121+
let x = Vector(x: 4, y: 5)
122+
let v = AnyDerivative(Vector.TangentVector(x: 2, y: 2))
123+
let 𝛁x = pullback(at: x, in: { AnyDifferentiable($0) })(v)
124+
let expectedVJP = Vector.TangentVector(x: 2, y: 2)
125+
expectEqual(expectedVJP, 𝛁x)
126+
}
127+
128+
do {
129+
let x = Generic<Double>(x: 4)
130+
let v = AnyDerivative(Generic<Double>.TangentVector(x: 2))
131+
let 𝛁x = pullback(at: x, in: { AnyDifferentiable($0) })(v)
132+
let expectedVJP = Generic<Double>.TangentVector(x: 2)
133+
expectEqual(expectedVJP, 𝛁x)
134+
}
135+
}
136+
137+
TypeErasureTests.test("AnyDerivative differentiation") {
85138
// Test `AnyDerivative` operations.
86139
func tripleSum(_ x: AnyDerivative, _ y: AnyDerivative) -> AnyDerivative {
87140
let sum = x + y

0 commit comments

Comments
 (0)