@@ -153,18 +153,28 @@ final class ComplexTests: XCTestCase {
153
153
}
154
154
155
155
func testVjpInit( ) {
156
- let pb = pullback ( at: 4 , - 3 ) { r, i in
156
+ var pb = pullback ( at: 4 , - 3 ) { r, i in
157
157
return Complex < Float > ( real: r, imaginary: i)
158
158
}
159
- XCTAssertEqual ( ( - 1 , 2 ) , pb ( Complex < Float > ( real: - 1 , imaginary: 2 ) ) )
159
+ var tanTuple = pb ( Complex < Float > ( real: - 1 , imaginary: 2 ) )
160
+ XCTAssertEqual ( - 1 , tanTuple. 0 )
161
+ XCTAssertEqual ( 2 , tanTuple. 1 )
162
+
163
+ pb = pullback ( at: 4 , - 3 ) { r, i in
164
+ return Complex < Float > ( real: r * r, imaginary: i + i)
165
+ }
166
+ tanTuple = pb ( Complex < Float > ( real: - 1 , imaginary: 1 ) )
167
+ XCTAssertEqual ( - 8 , tanTuple. 0 )
168
+ XCTAssertEqual ( 2 , tanTuple. 1 )
160
169
}
161
170
162
171
func testVjpAdd( ) {
163
172
let pb : ( Complex < Float > ) -> Complex < Float > =
164
173
pullback ( at: Complex < Float > ( real: 2 , imaginary: 3 ) ) { x in
165
174
return x + Complex < Float > ( real: 5 , imaginary: 6 )
166
175
}
167
- XCTAssertEqual ( pb ( Complex ( real: 1 , imaginary: 1 ) ) , Complex < Float > ( real: 1 , imaginary: 1 ) )
176
+ XCTAssertEqual ( pb ( Complex ( real: 1 , imaginary: 1 ) ) ,
177
+ Complex < Float > ( real: 1 , imaginary: 1 ) )
168
178
}
169
179
170
180
func testVjpSubtract( ) {
@@ -316,7 +326,7 @@ final class ComplexTests: XCTestCase {
316
326
}
317
327
318
328
XCTAssertEqual ( - 2 , result)
319
- XCTAssertEqual ( Complex ( real: 1 , imaginary: 1 ) , pbComplex ( 1 ) )
329
+ XCTAssertEqual ( Complex ( real: 1 , imaginary: 0 ) , pbComplex ( 1 ) )
320
330
}
321
331
322
332
static var allTests = [
0 commit comments