4
4
import _Differentiation
5
5
import StdlibUnittest
6
6
7
- var AnyDerivativeTests = TestSuite ( " AnyDerivative " )
7
+ var TypeErasureTests = TestSuite ( " DifferentiableTypeErasure " )
8
8
9
- struct Vector : Differentiable {
9
+ struct Vector : Differentiable , Equatable {
10
10
var x , y : Float
11
11
}
12
- struct Generic < T: Differentiable > : Differentiable {
12
+ struct Generic < T: Differentiable & Equatable > : Differentiable , Equatable {
13
13
var x : T
14
14
}
15
15
@@ -22,28 +22,44 @@ extension AnyDerivative {
22
22
}
23
23
}
24
24
25
- AnyDerivativeTests . test ( " Vector " ) {
26
- var tan = AnyDerivative ( Vector . TangentVector ( x: 1 , y: 1 ) )
27
- tan += tan
28
- expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 2 , y: 2 ) ) , tan)
29
- expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 4 , y: 4 ) ) , tan + tan)
30
- expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 0 , y: 0 ) ) , tan - tan)
31
- expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 4 , y: 4 ) ) , tan. moved ( along: tan) )
32
- expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 2 , y: 2 ) ) , tan)
25
+ TypeErasureTests . test ( " AnyDifferentiable operations " ) {
26
+ do {
27
+ var any = AnyDifferentiable ( Vector ( x: 1 , y: 1 ) )
28
+ let tan = AnyDerivative ( Vector . TangentVector ( x: 1 , y: 1 ) )
29
+ any. move ( along: tan)
30
+ expectEqual ( Vector ( x: 2 , y: 2 ) , any. base as? Vector )
31
+ }
32
+
33
+ do {
34
+ var any = AnyDifferentiable ( Generic < Float > ( x: 1 ) )
35
+ let tan = AnyDerivative ( Generic< Float> . TangentVector( x: 1 ) )
36
+ any. move ( along: tan)
37
+ expectEqual ( Generic < Float > ( x: 2 ) , any. base as? Generic < Float > )
38
+ }
33
39
}
34
40
35
- AnyDerivativeTests . test ( " Generic " ) {
36
- var tan = AnyDerivative ( Generic< Float> . TangentVector( x: 1 ) )
37
- let cotan = AnyDerivative ( Generic< Float> . TangentVector( x: 1 ) )
38
- tan += tan
39
- expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 2 ) ) , tan)
40
- expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 4 ) ) , tan + tan)
41
- expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 0 ) ) , tan - tan)
42
- expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 4 ) ) , tan. moved ( along: tan) )
43
- expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 1 ) ) , cotan)
41
+ TypeErasureTests . test ( " AnyDerivative operations " ) {
42
+ do {
43
+ var tan = AnyDerivative ( Vector . TangentVector ( x: 1 , y: 1 ) )
44
+ tan += tan
45
+ expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 2 , y: 2 ) ) , tan)
46
+ expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 4 , y: 4 ) ) , tan + tan)
47
+ expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 0 , y: 0 ) ) , tan - tan)
48
+ expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 4 , y: 4 ) ) , tan. moved ( along: tan) )
49
+ expectEqual ( AnyDerivative ( Vector . TangentVector ( x: 2 , y: 2 ) ) , tan)
50
+ }
51
+
52
+ do {
53
+ var tan = AnyDerivative ( Generic< Float> . TangentVector( x: 1 ) )
54
+ tan += tan
55
+ expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 2 ) ) , tan)
56
+ expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 4 ) ) , tan + tan)
57
+ expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 0 ) ) , tan - tan)
58
+ expectEqual ( AnyDerivative ( Generic< Float> . TangentVector( x: 4 ) ) , tan. moved ( along: tan) )
59
+ }
44
60
}
45
61
46
- AnyDerivativeTests . test ( " Zero " ) {
62
+ TypeErasureTests . test ( " AnyDerivative.zero " ) {
47
63
var zero = AnyDerivative . zero
48
64
zero += zero
49
65
zero -= zero
@@ -66,7 +82,17 @@ AnyDerivativeTests.test("Zero") {
66
82
expectEqual ( tan, tan)
67
83
}
68
84
69
- AnyDerivativeTests . test ( " Casting " ) {
85
+ TypeErasureTests . test ( " AnyDifferentiable casting " ) {
86
+ let any = AnyDifferentiable ( Vector ( x: 1 , y: 1 ) )
87
+ expectEqual ( Vector ( x: 1 , y: 1 ) , any. base as? Vector )
88
+
89
+ let genericAny = AnyDifferentiable ( Generic < Float > ( x: 1 ) )
90
+ expectEqual ( Generic < Float > ( x: 1 ) ,
91
+ genericAny. base as? Generic < Float > )
92
+ expectEqual ( nil , genericAny. base as? Generic < Double > )
93
+ }
94
+
95
+ TypeErasureTests . test ( " AnyDerivative casting " ) {
70
96
let tan = AnyDerivative ( Vector . TangentVector ( x: 1 , y: 1 ) )
71
97
expectEqual ( Vector . TangentVector ( x: 1 , y: 1 ) , tan. base as? Vector . TangentVector )
72
98
@@ -81,7 +107,34 @@ AnyDerivativeTests.test("Casting") {
81
107
expectEqual ( nil , zero. base as? Generic < Float > . TangentVector )
82
108
}
83
109
84
- AnyDerivativeTests . test ( " Derivatives " ) {
110
+ TypeErasureTests . test ( " AnyDifferentiable differentiation " ) {
111
+ // Test `AnyDifferentiable` initializer.
112
+ do {
113
+ let x : Float = 3
114
+ let v = AnyDerivative ( Float ( 2 ) )
115
+ let 𝛁x = pullback( at: x, in: { AnyDifferentiable ( $0) } ) ( v)
116
+ let expectedVJP : Float = 2
117
+ expectEqual ( expectedVJP, 𝛁x)
118
+ }
119
+
120
+ do {
121
+ let x = Vector ( x: 4 , y: 5 )
122
+ let v = AnyDerivative ( Vector . TangentVector ( x: 2 , y: 2 ) )
123
+ let 𝛁x = pullback( at: x, in: { AnyDifferentiable ( $0) } ) ( v)
124
+ let expectedVJP = Vector . TangentVector ( x: 2 , y: 2 )
125
+ expectEqual ( expectedVJP, 𝛁x)
126
+ }
127
+
128
+ do {
129
+ let x = Generic < Double > ( x: 4 )
130
+ let v = AnyDerivative ( Generic< Double> . TangentVector( x: 2 ) )
131
+ let 𝛁x = pullback( at: x, in: { AnyDifferentiable ( $0) } ) ( v)
132
+ let expectedVJP = Generic< Double> . TangentVector( x: 2 )
133
+ expectEqual ( expectedVJP, 𝛁x)
134
+ }
135
+ }
136
+
137
+ TypeErasureTests . test ( " AnyDerivative differentiation " ) {
85
138
// Test `AnyDerivative` operations.
86
139
func tripleSum( _ x: AnyDerivative , _ y: AnyDerivative ) -> AnyDerivative {
87
140
let sum = x + y
0 commit comments