|
| 1 | +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | +// |
| 15 | +/// Note |
| 16 | +/// ---- |
| 17 | +/// |
| 18 | +/// This implementation uses a modified implementation from the |
| 19 | +/// xwu/NumericAnnex Swift numeric library repo vy Xiaodi Wu. To view the |
| 20 | +/// original code, see the implementation here |
| 21 | +/// |
| 22 | +/// https://github.com/xwu/NumericAnnex/blob/master/Sources/Complex.swift |
| 23 | +/// |
| 24 | +/// Create new instances of `Complex<T>` using integer or floating-point |
| 25 | +/// literals and the imaginary unit `Complex<T>.i`. For example: |
| 26 | +/// |
| 27 | +/// ```swift |
| 28 | +/// let x: Complex<Double> = 2 + 4 * .i |
| 29 | +/// ``` |
| 30 | +/// |
| 31 | +/// Additional Considerations |
| 32 | +/// ------------------------- |
| 33 | +/// |
| 34 | +/// Our implementation of complex number differentiation follows the same |
| 35 | +/// convention as Autograd. In short, we can get the derivative of a |
| 36 | +/// holomorphic function, functions whose codomain are the Reals, and |
| 37 | +/// functions whose codomain and domain are the Reals. You can read more about |
| 38 | +/// Autograd at |
| 39 | +/// |
| 40 | +/// https://github.com/HIPS/autograd/blob/master/docs/tutorial.md#complex-numbers |
| 41 | +/// |
| 42 | +/// Floating-point types have special values that represent infinity or NaN |
| 43 | +/// ("not a number"). Complex functions in different languages may return |
| 44 | +/// different results when working with special values. |
| 45 | + |
| 46 | +struct Complex<T: FloatingPoint> { |
| 47 | + var real: T |
| 48 | + var imaginary: T |
| 49 | + |
| 50 | + @differentiable(vjp: _vjpInit where T: Differentiable, T.TangentVector == T) |
| 51 | + init(real: T = 0, imaginary: T = 0) { |
| 52 | + self.real = real |
| 53 | + self.imaginary = imaginary |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +extension Complex: Differentiable where T: Differentiable { |
| 58 | + typealias TangentVector = Complex |
| 59 | + typealias AllDifferentiableVariables = Complex |
| 60 | +} |
| 61 | + |
| 62 | +extension Complex { |
| 63 | + static var i: Complex { |
| 64 | + return Complex(real: 0, imaginary: 1) |
| 65 | + } |
| 66 | + |
| 67 | + var isFinite: Bool { |
| 68 | + return real.isFinite && imaginary.isFinite |
| 69 | + } |
| 70 | + |
| 71 | + var isInfinite: Bool { |
| 72 | + return real.isInfinite || imaginary.isInfinite |
| 73 | + } |
| 74 | + |
| 75 | + var isNaN: Bool { |
| 76 | + return (real.isNaN && !imaginary.isInfinite) || (imaginary.isNaN && !real.isInfinite) |
| 77 | + } |
| 78 | + |
| 79 | + var isZero: Bool { |
| 80 | + return real.isZero && imaginary.isZero |
| 81 | + } |
| 82 | +} |
| 83 | + |
| 84 | +extension Complex: ExpressibleByIntegerLiteral { |
| 85 | + init(integerLiteral value: Int) { |
| 86 | + self.real = T(value) |
| 87 | + self.imaginary = 0 |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +extension Complex: CustomStringConvertible { |
| 92 | + var description: String { |
| 93 | + return real.isNaN && real.sign == .minus |
| 94 | + ? imaginary.sign == .minus |
| 95 | + ? "-\(-real) - \(-imaginary)i" |
| 96 | + : "-\(-real) + \(imaginary)i" |
| 97 | + : imaginary.sign == .minus |
| 98 | + ? "\(real) - \(-imaginary)i" |
| 99 | + : "\(real) + \(imaginary)i" |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +extension Complex: Equatable { |
| 104 | + static func == (lhs: Complex, rhs: Complex) -> Bool { |
| 105 | + return lhs.real == rhs.real && lhs.imaginary == rhs.imaginary |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +extension Complex: AdditiveArithmetic { |
| 110 | + @differentiable(vjp: _vjpAdd(lhs:rhs:) where T: Differentiable) |
| 111 | + static func + (lhs: Complex, rhs: Complex) -> Complex { |
| 112 | + var temp = lhs |
| 113 | + temp += rhs |
| 114 | + return temp |
| 115 | + } |
| 116 | + |
| 117 | + static func += (lhs: inout Complex, rhs: Complex) { |
| 118 | + lhs.real += rhs.real |
| 119 | + lhs.imaginary += rhs.imaginary |
| 120 | + } |
| 121 | + |
| 122 | + @differentiable(vjp: _vjpSubtract(lhs:rhs:) where T: Differentiable) |
| 123 | + static func - (lhs: Complex, rhs: Complex) -> Complex { |
| 124 | + var temp = lhs |
| 125 | + temp -= rhs |
| 126 | + return temp |
| 127 | + } |
| 128 | + |
| 129 | + static func -= (lhs: inout Complex, rhs: Complex) { |
| 130 | + lhs.real -= rhs.real |
| 131 | + lhs.imaginary -= rhs.imaginary |
| 132 | + } |
| 133 | +} |
| 134 | + |
| 135 | +extension Complex: Numeric { |
| 136 | + init?<U>(exactly source: U) where U: BinaryInteger { |
| 137 | + guard let t = T(exactly: source) else { return nil } |
| 138 | + self.real = t |
| 139 | + self.imaginary = 0 |
| 140 | + } |
| 141 | + |
| 142 | + static private func handleMultiplyNaN(infiniteA: T, infiniteB: T, nanA: T, nanB: T) -> Complex { |
| 143 | + var a = infiniteA |
| 144 | + var b = infiniteB |
| 145 | + var c = nanA |
| 146 | + var d = nanB |
| 147 | + |
| 148 | + a = T(signOf: infiniteA, magnitudeOf: infiniteA.isInfinite ? 1 : 0) |
| 149 | + b = T(signOf: infiniteB, magnitudeOf: infiniteB.isInfinite ? 1 : 0) |
| 150 | + |
| 151 | + if nanA.isNaN { c = T(signOf: nanA, magnitudeOf: 0) } |
| 152 | + if nanB.isNaN { d = T(signOf: nanB, magnitudeOf: 0) } |
| 153 | + |
| 154 | + return Complex( |
| 155 | + real: .infinity * (a * c - b * d), |
| 156 | + imaginary: .infinity * (a * d + b * c) |
| 157 | + ) |
| 158 | + } |
| 159 | + |
| 160 | + @differentiable(vjp: _vjpMultiply(lhs:rhs:) where T: Differentiable) |
| 161 | + static func * (lhs: Complex, rhs: Complex) -> Complex { |
| 162 | + var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary |
| 163 | + let ac = a * c, bd = b * d, ad = a * d, bc = b * c |
| 164 | + let x = ac - bd |
| 165 | + let y = ad + bc |
| 166 | + |
| 167 | + if x.isNaN && y.isNaN { |
| 168 | + if a.isInfinite || b.isInfinite { |
| 169 | + return handleMultiplyNaN(infiniteA: a, infiniteB: b, nanA: c, nanB: d) |
| 170 | + } else if c.isInfinite || d.isInfinite { |
| 171 | + return handleMultiplyNaN(infiniteA: c, infiniteB: d, nanA: a, nanB: b) |
| 172 | + } else if ac.isInfinite || bd.isInfinite || ad.isInfinite || bc.isInfinite { |
| 173 | + if a.isNaN { a = T(signOf: a, magnitudeOf: 0) } |
| 174 | + if b.isNaN { b = T(signOf: b, magnitudeOf: 0) } |
| 175 | + if c.isNaN { c = T(signOf: c, magnitudeOf: 0) } |
| 176 | + if d.isNaN { d = T(signOf: d, magnitudeOf: 0) } |
| 177 | + return Complex( |
| 178 | + real: .infinity * (a * c - b * d), |
| 179 | + imaginary: .infinity * (a * d + b * c) |
| 180 | + ) |
| 181 | + } |
| 182 | + } |
| 183 | + return Complex(real: x, imaginary: y) |
| 184 | + } |
| 185 | + |
| 186 | + static func *= (lhs: inout Complex, rhs: Complex) { |
| 187 | + lhs = lhs * rhs |
| 188 | + } |
| 189 | + |
| 190 | + var magnitude: T { |
| 191 | + var x = abs(real) |
| 192 | + var y = abs(imaginary) |
| 193 | + if x.isInfinite { return x } |
| 194 | + if y.isInfinite { return y } |
| 195 | + if x == 0 { return y } |
| 196 | + if x < y { swap(&x, &y) } |
| 197 | + let ratio = y / x |
| 198 | + return x * (1 + ratio * ratio).squareRoot() |
| 199 | + } |
| 200 | +} |
| 201 | + |
| 202 | +extension Complex: SignedNumeric { |
| 203 | + @differentiable(vjp: _vjpNegate where T: Differentiable) |
| 204 | + static prefix func - (operand: Complex) -> Complex { |
| 205 | + return Complex(real: -operand.real, imaginary: -operand.imaginary) |
| 206 | + } |
| 207 | + |
| 208 | + mutating func negate() { |
| 209 | + real.negate() |
| 210 | + imaginary.negate() |
| 211 | + } |
| 212 | +} |
| 213 | + |
| 214 | +extension Complex { |
| 215 | + @differentiable(vjp: _vjpDivide(lhs:rhs:) where T: Differentiable) |
| 216 | + static func / (lhs: Complex, rhs: Complex) -> Complex { |
| 217 | + var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary |
| 218 | + var x: T |
| 219 | + var y: T |
| 220 | + if c.magnitude >= d.magnitude { |
| 221 | + let ratio = d / c |
| 222 | + let denominator = c + d * ratio |
| 223 | + x = (a + b * ratio) / denominator |
| 224 | + y = (b - a * ratio) / denominator |
| 225 | + } else { |
| 226 | + let ratio = c / d |
| 227 | + let denominator = c * ratio + d |
| 228 | + x = (a * ratio + b) / denominator |
| 229 | + y = (b * ratio - a) / denominator |
| 230 | + } |
| 231 | + if x.isNaN && y.isNaN { |
| 232 | + if c == 0 && d == 0 && (!a.isNaN || !b.isNaN) { |
| 233 | + x = T(signOf: c, magnitudeOf: .infinity) * a |
| 234 | + y = T(signOf: c, magnitudeOf: .infinity) * b |
| 235 | + } else if (a.isInfinite || b.isInfinite) && c.isFinite && d.isFinite { |
| 236 | + a = T(signOf: a, magnitudeOf: a.isInfinite ? 1 : 0) |
| 237 | + b = T(signOf: b, magnitudeOf: b.isInfinite ? 1 : 0) |
| 238 | + x = .infinity * (a * c + b * d) |
| 239 | + y = .infinity * (b * c - a * d) |
| 240 | + } else if (c.isInfinite || d.isInfinite) && a.isFinite && b.isFinite { |
| 241 | + c = T(signOf: c, magnitudeOf: c.isInfinite ? 1 : 0) |
| 242 | + d = T(signOf: d, magnitudeOf: d.isInfinite ? 1 : 0) |
| 243 | + x = 0 * (a * c + b * d) |
| 244 | + y = 0 * (b * c - a * d) |
| 245 | + } |
| 246 | + } |
| 247 | + return Complex(real: x, imaginary: y) |
| 248 | + } |
| 249 | + |
| 250 | + static func /= (lhs: inout Complex, rhs: Complex) { |
| 251 | + lhs = lhs / rhs |
| 252 | + } |
| 253 | +} |
| 254 | + |
| 255 | +extension Complex { |
| 256 | + @differentiable(vjp: _vjpComplexConjugate where T: Differentiable) |
| 257 | + func complexConjugate() -> Complex { |
| 258 | + return Complex(real: real, imaginary: -imaginary) |
| 259 | + } |
| 260 | +} |
| 261 | + |
| 262 | +func abs<T>(_ z: Complex<T>) -> Complex<T> { |
| 263 | + return Complex(real: z.magnitude) |
| 264 | +} |
| 265 | + |
| 266 | +extension Complex { |
| 267 | + @differentiable(vjp: _vjpAdding(real:) where T: Differentiable, T.TangentVector == T) |
| 268 | + func adding(real: T) -> Complex { |
| 269 | + var c = self |
| 270 | + c.real += real |
| 271 | + return c |
| 272 | + } |
| 273 | + |
| 274 | + @differentiable(vjp: _vjpSubtracting(real:) where T: Differentiable, T.TangentVector == T) |
| 275 | + func subtracting(real: T) -> Complex { |
| 276 | + var c = self |
| 277 | + c.real -= real |
| 278 | + return c |
| 279 | + } |
| 280 | + |
| 281 | + @differentiable(vjp: _vjpAdding(imaginary:) where T: Differentiable, T.TangentVector == T) |
| 282 | + func adding(imaginary: T) -> Complex { |
| 283 | + var c = self |
| 284 | + c.imaginary += imaginary |
| 285 | + return c |
| 286 | + } |
| 287 | + |
| 288 | + @differentiable(vjp: _vjpSubtracting(imaginary:) where T: Differentiable, T.TangentVector == T) |
| 289 | + func subtracting(imaginary: T) -> Complex { |
| 290 | + var c = self |
| 291 | + c.imaginary -= imaginary |
| 292 | + return c |
| 293 | + } |
| 294 | +} |
| 295 | + |
| 296 | +extension Complex where T: Differentiable, T.TangentVector == T { |
| 297 | + static func _vjpInit(real: T, imaginary: T) -> (Complex, (Complex) -> (T, T)) { |
| 298 | + return (Complex(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) }) |
| 299 | + } |
| 300 | +} |
| 301 | + |
| 302 | +extension Complex where T: Differentiable { |
| 303 | + static func _vjpAdd(lhs: Complex, rhs: Complex) |
| 304 | + -> (Complex, (Complex) -> (Complex, Complex)) { |
| 305 | + return (lhs + rhs, { v in (v, v) }) |
| 306 | + } |
| 307 | + |
| 308 | + static func _vjpSubtract(lhs: Complex, rhs: Complex) |
| 309 | + -> (Complex, (Complex) -> (Complex, Complex)) { |
| 310 | + return (lhs - rhs, { v in (v, -v) }) |
| 311 | + } |
| 312 | + |
| 313 | + static func _vjpMultiply(lhs: Complex, rhs: Complex) |
| 314 | + -> (Complex, (Complex) -> (Complex, Complex)) { |
| 315 | + return (lhs * rhs, { v in (rhs * v, lhs * v) }) |
| 316 | + } |
| 317 | + |
| 318 | + static func _vjpDivide(lhs: Complex, rhs: Complex) |
| 319 | + -> (Complex, (Complex) -> (Complex, Complex)) { |
| 320 | + return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) |
| 321 | + } |
| 322 | + |
| 323 | + static func _vjpNegate(operand: Complex) |
| 324 | + -> (Complex, (Complex) -> Complex) { |
| 325 | + return (-operand, { -$0 }) |
| 326 | + } |
| 327 | + |
| 328 | + func _vjpComplexConjugate() -> (Complex, (Complex) -> Complex) { |
| 329 | + return (complexConjugate(), { v in v.complexConjugate() }) |
| 330 | + } |
| 331 | +} |
| 332 | + |
| 333 | +extension Complex where T: Differentiable, T.TangentVector == T { |
| 334 | + func _vjpAdding(real: T) -> (Complex, (Complex) -> (Complex, T)) { |
| 335 | + return (self.adding(real: real), { ($0, $0.real) }) |
| 336 | + } |
| 337 | + |
| 338 | + func _vjpSubtracting(real: T) -> (Complex, (Complex) -> (Complex, T)) { |
| 339 | + return (self.subtracting(real: real), { ($0, -$0.real) }) |
| 340 | + } |
| 341 | + |
| 342 | + func _vjpAdding(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) { |
| 343 | + return (self.adding(real: real), { ($0, $0.imaginary) }) |
| 344 | + } |
| 345 | + |
| 346 | + func _vjpSubtracting(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) { |
| 347 | + return (self.subtracting(real: real), { ($0, -$0.imaginary) }) |
| 348 | + } |
| 349 | +} |
0 commit comments