Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c66cf59

Browse files
authored
Add Complex Numbers (#129)
Experimental implementation of differentiating complex numbers, not part of the TensorFlow library.
1 parent 2114912 commit c66cf59

File tree

4 files changed

+731
-0
lines changed

4 files changed

+731
-0
lines changed

Package.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ let package = Package(
3030
.target(
3131
name: "TensorFlow",
3232
dependencies: []),
33+
.target(
34+
name: "Experimental",
35+
dependencies: [],
36+
path: "Sources/third_party/Experimental"),
37+
.testTarget(
38+
name: "ExperimentalTests",
39+
dependencies: ["Experimental"]),
3340
.testTarget(
3441
name: "TensorFlowTests",
3542
dependencies: ["TensorFlow"]),
Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
/// Note
16+
/// ----
17+
///
18+
/// This implementation uses a modified implementation from the
19+
/// xwu/NumericAnnex Swift numeric library repo vy Xiaodi Wu. To view the
20+
/// original code, see the implementation here
21+
///
22+
/// https://github.com/xwu/NumericAnnex/blob/master/Sources/Complex.swift
23+
///
24+
/// Create new instances of `Complex<T>` using integer or floating-point
25+
/// literals and the imaginary unit `Complex<T>.i`. For example:
26+
///
27+
/// ```swift
28+
/// let x: Complex<Double> = 2 + 4 * .i
29+
/// ```
30+
///
31+
/// Additional Considerations
32+
/// -------------------------
33+
///
34+
/// Our implementation of complex number differentiation follows the same
35+
/// convention as Autograd. In short, we can get the derivative of a
36+
/// holomorphic function, functions whose codomain are the Reals, and
37+
/// functions whose codomain and domain are the Reals. You can read more about
38+
/// Autograd at
39+
///
40+
/// https://github.com/HIPS/autograd/blob/master/docs/tutorial.md#complex-numbers
41+
///
42+
/// Floating-point types have special values that represent infinity or NaN
43+
/// ("not a number"). Complex functions in different languages may return
44+
/// different results when working with special values.
45+
46+
struct Complex<T: FloatingPoint> {
47+
var real: T
48+
var imaginary: T
49+
50+
@differentiable(vjp: _vjpInit where T: Differentiable, T.TangentVector == T)
51+
init(real: T = 0, imaginary: T = 0) {
52+
self.real = real
53+
self.imaginary = imaginary
54+
}
55+
}
56+
57+
extension Complex: Differentiable where T: Differentiable {
58+
typealias TangentVector = Complex
59+
typealias AllDifferentiableVariables = Complex
60+
}
61+
62+
extension Complex {
63+
static var i: Complex {
64+
return Complex(real: 0, imaginary: 1)
65+
}
66+
67+
var isFinite: Bool {
68+
return real.isFinite && imaginary.isFinite
69+
}
70+
71+
var isInfinite: Bool {
72+
return real.isInfinite || imaginary.isInfinite
73+
}
74+
75+
var isNaN: Bool {
76+
return (real.isNaN && !imaginary.isInfinite) || (imaginary.isNaN && !real.isInfinite)
77+
}
78+
79+
var isZero: Bool {
80+
return real.isZero && imaginary.isZero
81+
}
82+
}
83+
84+
extension Complex: ExpressibleByIntegerLiteral {
85+
init(integerLiteral value: Int) {
86+
self.real = T(value)
87+
self.imaginary = 0
88+
}
89+
}
90+
91+
extension Complex: CustomStringConvertible {
92+
var description: String {
93+
return real.isNaN && real.sign == .minus
94+
? imaginary.sign == .minus
95+
? "-\(-real) - \(-imaginary)i"
96+
: "-\(-real) + \(imaginary)i"
97+
: imaginary.sign == .minus
98+
? "\(real) - \(-imaginary)i"
99+
: "\(real) + \(imaginary)i"
100+
}
101+
}
102+
103+
extension Complex: Equatable {
104+
static func == (lhs: Complex, rhs: Complex) -> Bool {
105+
return lhs.real == rhs.real && lhs.imaginary == rhs.imaginary
106+
}
107+
}
108+
109+
extension Complex: AdditiveArithmetic {
110+
@differentiable(vjp: _vjpAdd(lhs:rhs:) where T: Differentiable)
111+
static func + (lhs: Complex, rhs: Complex) -> Complex {
112+
var temp = lhs
113+
temp += rhs
114+
return temp
115+
}
116+
117+
static func += (lhs: inout Complex, rhs: Complex) {
118+
lhs.real += rhs.real
119+
lhs.imaginary += rhs.imaginary
120+
}
121+
122+
@differentiable(vjp: _vjpSubtract(lhs:rhs:) where T: Differentiable)
123+
static func - (lhs: Complex, rhs: Complex) -> Complex {
124+
var temp = lhs
125+
temp -= rhs
126+
return temp
127+
}
128+
129+
static func -= (lhs: inout Complex, rhs: Complex) {
130+
lhs.real -= rhs.real
131+
lhs.imaginary -= rhs.imaginary
132+
}
133+
}
134+
135+
extension Complex: Numeric {
136+
init?<U>(exactly source: U) where U: BinaryInteger {
137+
guard let t = T(exactly: source) else { return nil }
138+
self.real = t
139+
self.imaginary = 0
140+
}
141+
142+
static private func handleMultiplyNaN(infiniteA: T, infiniteB: T, nanA: T, nanB: T) -> Complex {
143+
var a = infiniteA
144+
var b = infiniteB
145+
var c = nanA
146+
var d = nanB
147+
148+
a = T(signOf: infiniteA, magnitudeOf: infiniteA.isInfinite ? 1 : 0)
149+
b = T(signOf: infiniteB, magnitudeOf: infiniteB.isInfinite ? 1 : 0)
150+
151+
if nanA.isNaN { c = T(signOf: nanA, magnitudeOf: 0) }
152+
if nanB.isNaN { d = T(signOf: nanB, magnitudeOf: 0) }
153+
154+
return Complex(
155+
real: .infinity * (a * c - b * d),
156+
imaginary: .infinity * (a * d + b * c)
157+
)
158+
}
159+
160+
@differentiable(vjp: _vjpMultiply(lhs:rhs:) where T: Differentiable)
161+
static func * (lhs: Complex, rhs: Complex) -> Complex {
162+
var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary
163+
let ac = a * c, bd = b * d, ad = a * d, bc = b * c
164+
let x = ac - bd
165+
let y = ad + bc
166+
167+
if x.isNaN && y.isNaN {
168+
if a.isInfinite || b.isInfinite {
169+
return handleMultiplyNaN(infiniteA: a, infiniteB: b, nanA: c, nanB: d)
170+
} else if c.isInfinite || d.isInfinite {
171+
return handleMultiplyNaN(infiniteA: c, infiniteB: d, nanA: a, nanB: b)
172+
} else if ac.isInfinite || bd.isInfinite || ad.isInfinite || bc.isInfinite {
173+
if a.isNaN { a = T(signOf: a, magnitudeOf: 0) }
174+
if b.isNaN { b = T(signOf: b, magnitudeOf: 0) }
175+
if c.isNaN { c = T(signOf: c, magnitudeOf: 0) }
176+
if d.isNaN { d = T(signOf: d, magnitudeOf: 0) }
177+
return Complex(
178+
real: .infinity * (a * c - b * d),
179+
imaginary: .infinity * (a * d + b * c)
180+
)
181+
}
182+
}
183+
return Complex(real: x, imaginary: y)
184+
}
185+
186+
static func *= (lhs: inout Complex, rhs: Complex) {
187+
lhs = lhs * rhs
188+
}
189+
190+
var magnitude: T {
191+
var x = abs(real)
192+
var y = abs(imaginary)
193+
if x.isInfinite { return x }
194+
if y.isInfinite { return y }
195+
if x == 0 { return y }
196+
if x < y { swap(&x, &y) }
197+
let ratio = y / x
198+
return x * (1 + ratio * ratio).squareRoot()
199+
}
200+
}
201+
202+
extension Complex: SignedNumeric {
203+
@differentiable(vjp: _vjpNegate where T: Differentiable)
204+
static prefix func - (operand: Complex) -> Complex {
205+
return Complex(real: -operand.real, imaginary: -operand.imaginary)
206+
}
207+
208+
mutating func negate() {
209+
real.negate()
210+
imaginary.negate()
211+
}
212+
}
213+
214+
extension Complex {
215+
@differentiable(vjp: _vjpDivide(lhs:rhs:) where T: Differentiable)
216+
static func / (lhs: Complex, rhs: Complex) -> Complex {
217+
var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary
218+
var x: T
219+
var y: T
220+
if c.magnitude >= d.magnitude {
221+
let ratio = d / c
222+
let denominator = c + d * ratio
223+
x = (a + b * ratio) / denominator
224+
y = (b - a * ratio) / denominator
225+
} else {
226+
let ratio = c / d
227+
let denominator = c * ratio + d
228+
x = (a * ratio + b) / denominator
229+
y = (b * ratio - a) / denominator
230+
}
231+
if x.isNaN && y.isNaN {
232+
if c == 0 && d == 0 && (!a.isNaN || !b.isNaN) {
233+
x = T(signOf: c, magnitudeOf: .infinity) * a
234+
y = T(signOf: c, magnitudeOf: .infinity) * b
235+
} else if (a.isInfinite || b.isInfinite) && c.isFinite && d.isFinite {
236+
a = T(signOf: a, magnitudeOf: a.isInfinite ? 1 : 0)
237+
b = T(signOf: b, magnitudeOf: b.isInfinite ? 1 : 0)
238+
x = .infinity * (a * c + b * d)
239+
y = .infinity * (b * c - a * d)
240+
} else if (c.isInfinite || d.isInfinite) && a.isFinite && b.isFinite {
241+
c = T(signOf: c, magnitudeOf: c.isInfinite ? 1 : 0)
242+
d = T(signOf: d, magnitudeOf: d.isInfinite ? 1 : 0)
243+
x = 0 * (a * c + b * d)
244+
y = 0 * (b * c - a * d)
245+
}
246+
}
247+
return Complex(real: x, imaginary: y)
248+
}
249+
250+
static func /= (lhs: inout Complex, rhs: Complex) {
251+
lhs = lhs / rhs
252+
}
253+
}
254+
255+
extension Complex {
256+
@differentiable(vjp: _vjpComplexConjugate where T: Differentiable)
257+
func complexConjugate() -> Complex {
258+
return Complex(real: real, imaginary: -imaginary)
259+
}
260+
}
261+
262+
func abs<T>(_ z: Complex<T>) -> Complex<T> {
263+
return Complex(real: z.magnitude)
264+
}
265+
266+
extension Complex {
267+
@differentiable(vjp: _vjpAdding(real:) where T: Differentiable, T.TangentVector == T)
268+
func adding(real: T) -> Complex {
269+
var c = self
270+
c.real += real
271+
return c
272+
}
273+
274+
@differentiable(vjp: _vjpSubtracting(real:) where T: Differentiable, T.TangentVector == T)
275+
func subtracting(real: T) -> Complex {
276+
var c = self
277+
c.real -= real
278+
return c
279+
}
280+
281+
@differentiable(vjp: _vjpAdding(imaginary:) where T: Differentiable, T.TangentVector == T)
282+
func adding(imaginary: T) -> Complex {
283+
var c = self
284+
c.imaginary += imaginary
285+
return c
286+
}
287+
288+
@differentiable(vjp: _vjpSubtracting(imaginary:) where T: Differentiable, T.TangentVector == T)
289+
func subtracting(imaginary: T) -> Complex {
290+
var c = self
291+
c.imaginary -= imaginary
292+
return c
293+
}
294+
}
295+
296+
extension Complex where T: Differentiable, T.TangentVector == T {
297+
static func _vjpInit(real: T, imaginary: T) -> (Complex, (Complex) -> (T, T)) {
298+
return (Complex(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
299+
}
300+
}
301+
302+
extension Complex where T: Differentiable {
303+
static func _vjpAdd(lhs: Complex, rhs: Complex)
304+
-> (Complex, (Complex) -> (Complex, Complex)) {
305+
return (lhs + rhs, { v in (v, v) })
306+
}
307+
308+
static func _vjpSubtract(lhs: Complex, rhs: Complex)
309+
-> (Complex, (Complex) -> (Complex, Complex)) {
310+
return (lhs - rhs, { v in (v, -v) })
311+
}
312+
313+
static func _vjpMultiply(lhs: Complex, rhs: Complex)
314+
-> (Complex, (Complex) -> (Complex, Complex)) {
315+
return (lhs * rhs, { v in (rhs * v, lhs * v) })
316+
}
317+
318+
static func _vjpDivide(lhs: Complex, rhs: Complex)
319+
-> (Complex, (Complex) -> (Complex, Complex)) {
320+
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
321+
}
322+
323+
static func _vjpNegate(operand: Complex)
324+
-> (Complex, (Complex) -> Complex) {
325+
return (-operand, { -$0 })
326+
}
327+
328+
func _vjpComplexConjugate() -> (Complex, (Complex) -> Complex) {
329+
return (complexConjugate(), { v in v.complexConjugate() })
330+
}
331+
}
332+
333+
extension Complex where T: Differentiable, T.TangentVector == T {
334+
func _vjpAdding(real: T) -> (Complex, (Complex) -> (Complex, T)) {
335+
return (self.adding(real: real), { ($0, $0.real) })
336+
}
337+
338+
func _vjpSubtracting(real: T) -> (Complex, (Complex) -> (Complex, T)) {
339+
return (self.subtracting(real: real), { ($0, -$0.real) })
340+
}
341+
342+
func _vjpAdding(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) {
343+
return (self.adding(real: real), { ($0, $0.imaginary) })
344+
}
345+
346+
func _vjpSubtracting(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) {
347+
return (self.subtracting(real: real), { ($0, -$0.imaginary) })
348+
}
349+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.

0 commit comments

Comments
 (0)