Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit fdd1c50

Browse files
committed
Move around tests and file for complex.
1 parent 23870cb commit fdd1c50

File tree

3 files changed

+255
-15
lines changed

3 files changed

+255
-15
lines changed

Sources/TensorFlow/Core/Complex.swift renamed to Sources/DeepLearning/Complex.swift

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ public struct Complex<T : FloatingPoint> {
22
public var real: T
33
public var imaginary: T
44

5-
// TODO: Make differentiable, crashing right now
5+
@differentiable(vjp: _vjpInit where T : Differentiable, T.TangentVector == T)
66
public init(real: T = 0, imaginary: T = 0) {
77
self.real = real
88
self.imaginary = imaginary
@@ -15,28 +15,23 @@ extension Complex : Differentiable where T : Differentiable {
1515
}
1616

1717
extension Complex {
18-
1918
public static var i: Complex {
2019
return Complex(real: 0, imaginary: 1)
2120
}
2221

23-
2422
public var isFinite: Bool {
2523
return real.isFinite && imaginary.isFinite
2624
}
2725

28-
2926
public var isInfinite: Bool {
3027
return real.isInfinite || imaginary.isInfinite
3128
}
3229

33-
3430
public var isNaN: Bool {
3531
return (real.isNaN && !imaginary.isInfinite) ||
3632
(imaginary.isNaN && !real.isInfinite)
3733
}
3834

39-
4035
public var isZero: Bool {
4136
return real.isZero && imaginary.isZero
4237
}
@@ -79,21 +74,18 @@ extension Complex : AdditiveArithmetic {
7974
return temp
8075
}
8176

82-
8377
public static func += (lhs: inout Complex, rhs: Complex) {
8478
lhs.real += rhs.real
8579
lhs.imaginary += rhs.imaginary
8680
}
8781

88-
8982
@differentiable(vjp: _vjpSubtract(lhs:rhs:) where T : Differentiable)
9083
public static func - (lhs: Complex, rhs: Complex) -> Complex {
9184
var temp = lhs
9285
temp -= rhs
9386
return temp
9487
}
9588

96-
9789
public static func -= (lhs: inout Complex, rhs: Complex) {
9890
lhs.real -= rhs.real
9991
lhs.imaginary -= rhs.imaginary
@@ -149,12 +141,10 @@ extension Complex : Numeric {
149141
return Complex(real: x, imaginary: y)
150142
}
151143

152-
153144
public static func *= (lhs: inout Complex, rhs: Complex) {
154145
lhs = lhs * rhs
155146
}
156147

157-
158148
public var magnitude: T {
159149
var x = abs(real)
160150
var y = abs(imaginary)
@@ -174,7 +164,6 @@ extension Complex : SignedNumeric {
174164
return Complex(real: -operand.real, imaginary: -operand.imaginary)
175165
}
176166

177-
178167
public mutating func negate() {
179168
real.negate()
180169
imaginary.negate()
@@ -246,23 +235,20 @@ extension Complex {
246235
return c
247236
}
248237

249-
250238
@differentiable(vjp: _vjpSubtracting(real:) where T : Differentiable, T.TangentVector == T)
251239
public func subtracting(real: T) -> Complex {
252240
var c = self
253241
c.real -= real
254242
return c
255243
}
256244

257-
258245
@differentiable(vjp: _vjpAdding(imaginary:) where T : Differentiable, T.TangentVector == T)
259246
public func adding(imaginary: T) -> Complex {
260247
var c = self
261248
c.imaginary += imaginary
262249
return c
263250
}
264251

265-
266252
@differentiable(vjp: _vjpSubtracting(imaginary:) where T : Differentiable, T.TangentVector == T)
267253
public func subtracting(imaginary: T) -> Complex {
268254
var c = self
@@ -271,6 +257,13 @@ extension Complex {
271257
}
272258
}
273259

260+
extension Complex where T : Differentiable, T.TangentVector == T {
261+
@usableFromInline
262+
static func _vjpInit(real: T, imaginary: T) -> (Complex, (Complex) -> (T, T)) {
263+
return (Complex(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
264+
}
265+
}
266+
274267
extension Complex where T : Differentiable {
275268
@usableFromInline
276269
static func _vjpAdd(lhs: Complex, rhs: Complex)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
//
4+
// Complex API tests.
5+
6+
// TODO: remove import
7+
import TensorFlow
8+
9+
import StdlibUnittest
10+
11+
var AutoDiffComplexTests = TestSuite("AutoDiffComplex")
12+
13+
AutoDiffComplexTests.test("_vjpAdd") {
14+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 2, imaginary: 3)) { x in
15+
return x + Complex<Float>(real: 5, imaginary: 6)
16+
}
17+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: 1, imaginary: 1))
18+
}
19+
20+
AutoDiffComplexTests.test("_vjpSubtract") {
21+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 2, imaginary: 3)) { x in
22+
return Complex<Float>(real: 5, imaginary: 6) - x
23+
}
24+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: -1, imaginary: -1))
25+
}
26+
27+
AutoDiffComplexTests.test("_vjpMultiply") {
28+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 2, imaginary: 3)) { x in
29+
return x * x
30+
}
31+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: 4, imaginary: 6))
32+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: -6, imaginary: 4))
33+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: -2, imaginary: 10))
34+
}
35+
36+
AutoDiffComplexTests.test("_vjpDivide") {
37+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
38+
return x / Complex<Float>(real: 2, imaginary: 2)
39+
}
40+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: 0.25, imaginary: -0.25))
41+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0.25, imaginary: 0.25))
42+
}
43+
44+
AutoDiffComplexTests.test("_vjpNegate") {
45+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
46+
return -x
47+
}
48+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: -1, imaginary: 0))
49+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0, imaginary: -1))
50+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: -1, imaginary: -1))
51+
}
52+
53+
AutoDiffComplexTests.test("_vjpComplexConjugate") {
54+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
55+
return x.complexConjugate()
56+
}
57+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: -1, imaginary: 0))
58+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0, imaginary: -1))
59+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: -1, imaginary: -1))
60+
}
61+
62+
AutoDiffComplexTests.test("_vjpAdding(real:)") {
63+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
64+
return x.adding(real: 5)
65+
}
66+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: 1, imaginary: 0))
67+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0, imaginary: 1))
68+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: 1, imaginary: 1))
69+
}
70+
71+
AutoDiffComplexTests.test("_vjpAdding(imaginary:)") {
72+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
73+
return x.adding(imaginary: 5)
74+
}
75+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: 1, imaginary: 0))
76+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0, imaginary: 1))
77+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: 1, imaginary: 1))
78+
}
79+
80+
AutoDiffComplexTests.test("_vjpSubtracting(real:)") {
81+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
82+
return x.subtracting(real: 5)
83+
}
84+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: 1, imaginary: 0))
85+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0, imaginary: 1))
86+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: 1, imaginary: 1))
87+
}
88+
89+
AutoDiffComplexTests.test("_vjpSubtracting(imaginary:)") {
90+
let pb: (Complex<Float>) -> Complex<Float> = pullback(at: Complex<Float>(real: 20, imaginary: -4)) { x in
91+
return x.subtracting(imaginary: 5)
92+
}
93+
expectEqual(pb(Complex(real: 1, imaginary: 0)), Complex<Float>(real: 1, imaginary: 0))
94+
expectEqual(pb(Complex(real: 0, imaginary: 1)), Complex<Float>(real: 0, imaginary: 1))
95+
expectEqual(pb(Complex(real: 1, imaginary: 1)), Complex<Float>(real: 1, imaginary: 1))
96+
}
97+
98+
runAllTests()
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
//
4+
// Complex API tests.
5+
6+
// TODO: remove import
7+
import TensorFlow
8+
9+
import StdlibUnittest
10+
11+
var ComplexTests = TestSuite("Complex")
12+
13+
ComplexTests.test("Initializer") {
14+
let complex = Complex<Float>(real: 2, imaginary: 3)
15+
expectEqual(complex.real, 2)
16+
expectEqual(complex.imaginary, 3)
17+
}
18+
19+
ComplexTests.test("Static Imaginary") {
20+
let imaginary = Complex<Float>(real: 0, imaginary: 1)
21+
expectEqual(imaginary, Complex.i)
22+
}
23+
24+
ComplexTests.test("isFinite") {
25+
var complex = Complex<Float>(real: 999, imaginary: 0)
26+
expectTrue(complex.isFinite)
27+
28+
complex = Complex(real: 1.0 / 0.0, imaginary: 1)
29+
expectFalse(complex.isFinite)
30+
31+
complex = Complex(real: 1.0 / 0.0, imaginary: 1.0 / 0.0)
32+
expectFalse(complex.isFinite)
33+
}
34+
35+
ComplexTests.test("isInfinite") {
36+
var complex = Complex<Float>(real: 999, imaginary: 0)
37+
expectFalse(complex.isInfinite)
38+
39+
complex = Complex(real: 1.0 / 0.0, imaginary: 1)
40+
expectTrue(complex.isInfinite)
41+
42+
complex = Complex(real: 1.0 / 0.0, imaginary: 1.0 / 0.0)
43+
expectTrue(complex.isInfinite)
44+
}
45+
46+
ComplexTests.test("isNaN") {
47+
var complex = Complex<Float>(real: 999, imaginary: 0)
48+
expectFalse(complex.isNaN)
49+
50+
complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 1)
51+
expectTrue(complex.isNaN)
52+
53+
complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 0.0 * 1.0 / 0.0)
54+
expectTrue(complex.isNaN)
55+
}
56+
57+
ComplexTests.test("isZero") {
58+
var complex = Complex<Float>(real: 999, imaginary: 0)
59+
expectFalse(complex.isZero)
60+
61+
complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 0)
62+
expectFalse(complex.isZero)
63+
64+
complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 0.0 * 1.0 / 0.0)
65+
expectFalse(complex.isZero)
66+
67+
complex = Complex(real: 0, imaginary: 0)
68+
expectTrue(complex.isZero)
69+
}
70+
71+
ComplexTests.test("==") {
72+
var complexA = Complex<Float>(real: 999, imaginary: 0)
73+
let complexB = Complex<Float>(real: 999, imaginary: 0)
74+
expectEqual(complexA, complexB)
75+
76+
complexA = Complex(real: 5, imaginary: 0)
77+
expectNotEqual(complexA, complexB)
78+
}
79+
80+
ComplexTests.test("+") {
81+
let input = Complex<Float>(real: 5, imaginary: 1)
82+
let expected = Complex<Float>(real: 10, imaginary: 2)
83+
expectEqual(expected, input + input)
84+
}
85+
86+
ComplexTests.test("-") {
87+
let inputA = Complex<Float>(real: 6, imaginary: 2)
88+
let inputB = Complex<Float>(real: 5, imaginary: 1)
89+
let expected = Complex<Float>(real: 1, imaginary: 1)
90+
expectEqual(expected, inputA - inputB)
91+
}
92+
93+
ComplexTests.test("*") {
94+
let inputA = Complex<Float>(real: 6, imaginary: 2)
95+
let inputB = Complex<Float>(real: 5, imaginary: 1)
96+
let expected = Complex<Float>(real: 28, imaginary: 16)
97+
expectEqual(expected, inputA * inputB)
98+
}
99+
100+
ComplexTests.test("negate") {
101+
var input = Complex<Float>(real: 6, imaginary: 2)
102+
let negated = Complex<Float>(real: -6, imaginary: -2)
103+
expectEqual(-input, negated)
104+
input.negate()
105+
expectEqual(input, negated)
106+
}
107+
108+
ComplexTests.test("/") {
109+
let inputA = Complex<Float>(real: 20, imaginary: -4)
110+
let inputB = Complex<Float>(real: 3, imaginary: 2)
111+
let expected = Complex<Float>(real: 4, imaginary: -4)
112+
expectEqual(expected, inputA / inputB)
113+
}
114+
115+
ComplexTests.test("complexConjugate") {
116+
var input = Complex<Float>(real: 2, imaginary: -4)
117+
var expected = Complex<Float>(real: 2, imaginary: 4)
118+
expectEqual(expected, input.complexConjugate())
119+
120+
input = Complex<Float>(real: -2, imaginary: -4)
121+
expected = Complex<Float>(real: -2, imaginary: 4)
122+
expectEqual(expected, input.complexConjugate())
123+
124+
input = Complex<Float>(real: 2, imaginary: 4)
125+
expected = Complex<Float>(real: 2, imaginary: -4)
126+
expectEqual(expected, input.complexConjugate())
127+
}
128+
129+
ComplexTests.test("adding") {
130+
var input = Complex<Float>(real: 2, imaginary: -4)
131+
var expected = Complex<Float>(real: 3, imaginary: -4)
132+
expectEqual(expected, input.adding(real: 1))
133+
134+
input = Complex<Float>(real: 2, imaginary: -4)
135+
expected = Complex<Float>(real: 2, imaginary: -3)
136+
expectEqual(expected, input.adding(imaginary: 1))
137+
}
138+
139+
ComplexTests.test("subtracting") {
140+
var input = Complex<Float>(real: 2, imaginary: -4)
141+
var expected = Complex<Float>(real: 1, imaginary: -4)
142+
expectEqual(expected, input.subtracting(real: 1))
143+
144+
input = Complex<Float>(real: 2, imaginary: -4)
145+
expected = Complex<Float>(real: 2, imaginary: -5)
146+
expectEqual(expected, input.subtracting(imaginary: 1))
147+
}
148+
149+
runAllTests()

0 commit comments

Comments
 (0)