@@ -75,6 +75,7 @@ AnyDerivativeTests.test("Casting") {
75
75
}
76
76
77
77
AnyDerivativeTests . test ( " Derivatives " ) {
78
+ // Test `AnyDerivative` operations.
78
79
func tripleSum( _ x: AnyDerivative , _ y: AnyDerivative ) -> AnyDerivative {
79
80
let sum = x + y
80
81
return sum + sum + sum
@@ -112,6 +113,42 @@ AnyDerivativeTests.test("Derivatives") {
112
113
expectEqual ( expectedVJP, 𝛁x. base as? Generic < Double > . CotangentVector )
113
114
expectEqual ( expectedVJP, 𝛁y. base as? Generic < Double > . CotangentVector )
114
115
}
116
+
117
+ // Test `AnyDerivative` initializer.
118
+ func typeErased< T> ( _ x: T ) -> AnyDerivative
119
+ where T : Differentiable , T. TangentVector == T ,
120
+ T. AllDifferentiableVariables == T ,
121
+ // NOTE: The requirement below should be defined on `Differentiable`.
122
+ // But it causes a crash due to generic signature minimization bug.
123
+ T. CotangentVector == T . CotangentVector . AllDifferentiableVariables
124
+ {
125
+ let any = AnyDerivative ( x)
126
+ return any + any
127
+ }
128
+
129
+ do {
130
+ let x : Float = 3
131
+ let v = AnyDerivative ( Float ( 1 ) )
132
+ let 𝛁x = pullback( at: x, in: { x in typeErased ( x) } ) ( v)
133
+ let expectedVJP : Float = 2
134
+ expectEqual ( expectedVJP, 𝛁x)
135
+ }
136
+
137
+ do {
138
+ let x = Vector . TangentVector ( x: 4 , y: 5 )
139
+ let v = AnyDerivative ( Vector . CotangentVector ( x: 1 , y: 1 ) )
140
+ let 𝛁x = pullback( at: x, in: { x in typeErased ( x) } ) ( v)
141
+ let expectedVJP = Vector . CotangentVector ( x: 2 , y: 2 )
142
+ expectEqual ( expectedVJP, 𝛁x)
143
+ }
144
+
145
+ do {
146
+ let x = Generic< Double> . TangentVector( x: 4 )
147
+ let v = AnyDerivative ( Generic< Double> . CotangentVector( x: 1 ) )
148
+ let 𝛁x = pullback( at: x, in: { x in typeErased ( x) } ) ( v)
149
+ let expectedVJP = Generic< Double> . CotangentVector( x: 2 )
150
+ expectEqual ( expectedVJP, 𝛁x)
151
+ }
115
152
}
116
153
117
154
runAllTests ( )
0 commit comments