Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 33b178f

Browse files
committed
Add vjpInit test.
1 parent 1b4c33a commit 33b178f

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

Sources/third_party/Experimental/Complex.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@
4444
/// different results when working with special values.
4545

4646
struct Complex<T: FloatingPoint> {
47-
@differentiable
4847
var real: T
49-
@differentiable
5048
var imaginary: T
5149

5250
@differentiable(vjp: _vjpInit where T: Differentiable, T.TangentVector == T)

Tests/ExperimentalTests/ComplexTests.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ final class ComplexTests: XCTestCase {
151151
expected = Complex<Float>(real: 2, imaginary: -5)
152152
XCTAssertEqual(expected, input.subtracting(imaginary: 1))
153153
}
154+
155+
func testVjpInit() {
156+
let pb = pullback(at: 4, -3) { r, i in
157+
return Complex<Float>(real: r, imaginary: i)
158+
}
159+
XCTAssertEqual((-1, 2), pb(Complex<Float>(real: -1, imaginary: 2)))
160+
}
154161

155162
func testVjpAdd() {
156163
let pb: (Complex<Float>) -> Complex<Float> =
@@ -253,7 +260,6 @@ final class ComplexTests: XCTestCase {
253260

254261
func testJvpDotProduct() {
255262
struct ComplexVector : Differentiable & AdditiveArithmetic {
256-
@differentiable
257263
var w: Complex<Float>
258264
var x: Complex<Float>
259265
var y: Complex<Float>
@@ -329,6 +335,7 @@ final class ComplexTests: XCTestCase {
329335
("testComplexConjugate", testComplexConjugate),
330336
("testAdding", testAdding),
331337
("testSubtracting", testSubtracting),
338+
("testVjpInit", testVjpInit),
332339
("testVjpAdd", testVjpAdd),
333340
("testVjpSubtract", testVjpSubtract),
334341
("testVjpMultiply", testVjpMultiply),

0 commit comments

Comments
 (0)