-
Notifications
You must be signed in to change notification settings - Fork 137
Add Complex Numbers #129
Add Complex Numbers #129
Changes from 21 commits
ee905cf
9c2d11b
346b9b7
d5e78cf
de575fe
683366d
01efbd9
23870cb
fdd1c50
3e1d56e
1e5758a
e58f400
9e62906
4ac97e3
3350c21
85b43bb
46e68c9
dc4930c
e4b7106
ae53238
09c53d0
1856a2f
90bff29
a7d88bb
57bca41
40522c6
4e2053a
80cab5c
b5cf071
1b4c33a
33b178f
59989bd
4e01fed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,363 @@ | ||
// Copyright 2017-2019 Xiaodi Wu and The TensorFlow Authors. All Rights Reserved. | ||
// | ||
// Licensed under the MIT License (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/MIT | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
// | ||
// Note | ||
// ==== | ||
// | ||
// For maximum consistency with corresponding functions in C/C++, checks for | ||
// special values in `naturalExponential()`, `squareRoot()`, trigonometric | ||
// functions, and hyperbolic functions are adapted from libc++. | ||
// | ||
// Code in libc++ is dual-licensed under the MIT and UIUC/NCSA licenses. | ||
// Copyright © 2009-2017 contributors to the LLVM/libc++ project. | ||
/// A type to represent a complex value in Cartesian form. | ||
/// | ||
/// - Note: `Complex64` is a type alias for `Complex<Float>` and `Complex128` is | ||
/// a type alias for `Complex<Double>`. | ||
/// | ||
/// Create new instances of `Complex<T>` using integer or floating-point | ||
/// literals and the imaginary unit `Complex<T>.i`. For example: | ||
/// | ||
/// ```swift | ||
/// let x = 2 + 4 * .i // `x` is of type `Complex<Double>` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this will type infer correctly given that there is no contextual complex type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh yes you're right I would need to add more context like I added this documentation back when @rxwei pointed out that I should keep the original header by Xiaodi. However, I think I skimmed over to quickly when I read the following:
in the google third-party info (and also I didn't closely read the original header and missed this typo you caught). I'll try to keep most of the documentation, however, looking over it, some of it may not be relevant to the current implementation since certain functions are specific to Xiaodi's library. |
||
/// let y = 3.5 + 7 * .i // `y` is of type `Complex<Double>` | ||
/// | ||
/// let z: Complex64 = .e + .pi * .i // `z` is of type `Complex<Float>` | ||
/// ``` | ||
/// | ||
/// Additional Considerations | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should add an additional Additional Consideration linking to https://github.com/HIPS/autograd/blob/master/docs/tutorial.md#complex-numbers and explaining that we are using that convention. It's a very non-obvious choice about how to do things. The first time I thought about complex differentiation, I thought that we should only define derivatives for holomorphic functions, but that convention lets us define derivatives for all functions and that link explains why that's a good thing to do. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed I added the following: /// Our implementation of complex number differentiation follows the same
/// convention as Autograd. In short, we can get the derivative of a
/// holomorphic function, functions whose codomain are the Reals, and
/// functions whose codomain and domain are the Reals. You can read more about
/// Autograd at
///
/// https://github.com/HIPS/autograd/blob/master/docs/tutorial.md#complex-numbers |
||
/// ------------------------- | ||
/// | ||
/// Floating-point types have special values that represent infinity or NaN | ||
/// ("not a number"). Complex functions in different languages may return | ||
/// different results when working with special values. | ||
/// | ||
/// Many complex functions have [branch cuts][dfn], which are curves in the | ||
/// complex plane across which a function is discontinuous. Different languages | ||
/// may adopt different branch cut structures for the same complex function. | ||
/// | ||
/// Implementations in `Complex<T>` adhere to the [C standard][std] (Annex G) as | ||
/// closely as possible with respect to special values and branch cuts. | ||
/// | ||
/// To users unfamiliar with complex functions, the principal value returned by | ||
/// some complex functions may be unexpected. For example, | ||
/// `Double.cbrt(-8) == -2`, which is the __real root__, while | ||
/// `Complex.cbrt(-8) == 2 * Complex.exp(.i * .pi / 3)`, which is the | ||
/// __principal root__. | ||
/// | ||
/// [dfn]: http://mathworld.wolfram.com/BranchCut.html | ||
/// [std]: http://www.open-std.org/JTC1/SC22/WG14/www/standards.html#9899 | ||
|
||
struct Complex<T: FloatingPoint> { | ||
var real: T | ||
var imaginary: T | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have you tried making In practice, the compiler can't quite derive these yet because most of the complex operations are implemented in terms of mutation and control flow that the compiler can't handle. But soon it should be able to handle it! It would be very interesting to add some tests that define a few operations in ways that autodiff can handle, and then see if autodiff computes the expected derivatives for those. e.g.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe derived conformances already made stored properties differentiable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I did in that when I mark a method as differentiable, I put the requirement that
I'll give it a shot! 😄 |
||
|
||
@differentiable(vjp: _vjpInit where T: Differentiable, T.TangentVector == T) | ||
init(real: T = 0, imaginary: T = 0) { | ||
self.real = real | ||
self.imaginary = imaginary | ||
} | ||
} | ||
|
||
extension Complex: Differentiable where T: Differentiable { | ||
typealias TangentVector = Complex | ||
typealias AllDifferentiableVariables = Complex | ||
} | ||
|
||
extension Complex { | ||
static var i: Complex { | ||
return Complex(real: 0, imaginary: 1) | ||
} | ||
|
||
var isFinite: Bool { | ||
return real.isFinite && imaginary.isFinite | ||
} | ||
|
||
var isInfinite: Bool { | ||
return real.isInfinite || imaginary.isInfinite | ||
} | ||
|
||
var isNaN: Bool { | ||
return (real.isNaN && !imaginary.isInfinite) || | ||
(imaginary.isNaN && !real.isInfinite) | ||
bartchr808 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
var isZero: Bool { | ||
return real.isZero && imaginary.isZero | ||
} | ||
} | ||
|
||
extension Complex: ExpressibleByIntegerLiteral { | ||
init(integerLiteral value: Int) { | ||
self.real = T(value) | ||
self.imaginary = 0 | ||
} | ||
} | ||
|
||
extension Complex: CustomStringConvertible { | ||
var description: String { | ||
return real.isNaN && real.sign == .minus | ||
? imaginary.sign == .minus | ||
? "-\(-real) - \(-imaginary)i" | ||
: "-\(-real) + \(imaginary)i" | ||
: imaginary.sign == .minus | ||
? "\(real) - \(-imaginary)i" | ||
: "\(real) + \(imaginary)i" | ||
} | ||
} | ||
|
||
extension Complex: Equatable { | ||
static func == (lhs: Complex, rhs: Complex) -> Bool { | ||
return lhs.real == rhs.real && lhs.imaginary == rhs.imaginary | ||
} | ||
} | ||
|
||
extension Complex: AdditiveArithmetic { | ||
@differentiable(vjp: _vjpAdd(lhs:rhs:) where T: Differentiable) | ||
static func + (lhs: Complex, rhs: Complex) -> Complex { | ||
var temp = lhs | ||
temp += rhs | ||
return temp | ||
} | ||
|
||
static func += (lhs: inout Complex, rhs: Complex) { | ||
lhs.real += rhs.real | ||
lhs.imaginary += rhs.imaginary | ||
} | ||
|
||
@differentiable(vjp: _vjpSubtract(lhs:rhs:) where T: Differentiable) | ||
static func - (lhs: Complex, rhs: Complex) -> Complex { | ||
var temp = lhs | ||
temp -= rhs | ||
return temp | ||
} | ||
|
||
static func -= (lhs: inout Complex, rhs: Complex) { | ||
lhs.real -= rhs.real | ||
lhs.imaginary -= rhs.imaginary | ||
} | ||
} | ||
|
||
extension Complex: Numeric { | ||
init?<U>(exactly source: U) where U: BinaryInteger { | ||
guard let t = T(exactly: source) else { return nil } | ||
self.real = t | ||
self.imaginary = 0 | ||
} | ||
|
||
@differentiable(vjp: _vjpMultiply(lhs:rhs:) where T: Differentiable) | ||
static func * (lhs: Complex, rhs: Complex) -> Complex { | ||
var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary | ||
let ac = a * c, bd = b * d, ad = a * d, bc = b * c | ||
let x = ac - bd | ||
let y = ad + bc | ||
|
||
if x.isNaN && y.isNaN { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this logic would read a lot nicer if the body of this 'if' were pulled out of line into its own function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True it could explain the logic behind this part of the function better. I'll give this a shot Sunday night! |
||
var recalculate = false | ||
if a.isInfinite || b.isInfinite { | ||
a = T(signOf: a, magnitudeOf: a.isInfinite ? 1 : 0) | ||
b = T(signOf: b, magnitudeOf: b.isInfinite ? 1 : 0) | ||
if c.isNaN { c = T(signOf: c, magnitudeOf: 0) } | ||
if d.isNaN { d = T(signOf: d, magnitudeOf: 0) } | ||
recalculate = true | ||
} | ||
if c.isInfinite || d.isInfinite { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm pretty sure this only need to run if 'recalculate' is false. The code structure could be cleaned up (and recalculate eliminated?) if pulled out of line I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the comment above, I'll take a look at this logic more closely and see what I can pull out, what code I can simplify, and also explain with comments Sunday night. As background, I kept this code the way it was just like in the NumericAnnex library. It had written a lot of tests, so I did a bit of a blind trust on the implementation of this logic. Additionally, since I was mainly looking at the automatic differentiation aspect of complex numbers, I didn't spend much time on the implementation of the basic operators. However, since this will be a part of our repo, I see that it's good to make sure all our code is well documented and explained. |
||
if a.isNaN { a = T(signOf: a, magnitudeOf: 0) } | ||
if b.isNaN { b = T(signOf: b, magnitudeOf: 0) } | ||
c = T(signOf: c, magnitudeOf: c.isInfinite ? 1 : 0) | ||
d = T(signOf: d, magnitudeOf: d.isInfinite ? 1 : 0) | ||
recalculate = true | ||
} | ||
if !recalculate && | ||
(ac.isInfinite || bd.isInfinite || ad.isInfinite || bc.isInfinite) { | ||
if a.isNaN { a = T(signOf: a, magnitudeOf: 0) } | ||
if b.isNaN { b = T(signOf: b, magnitudeOf: 0) } | ||
if c.isNaN { c = T(signOf: c, magnitudeOf: 0) } | ||
if d.isNaN { d = T(signOf: d, magnitudeOf: 0) } | ||
recalculate = true | ||
} | ||
if recalculate { | ||
return Complex( | ||
real: .infinity * (a * c - b * d), | ||
imaginary: .infinity * (a * d + b * c) | ||
) | ||
} | ||
} | ||
return Complex(real: x, imaginary: y) | ||
} | ||
|
||
static func *= (lhs: inout Complex, rhs: Complex) { | ||
lhs = lhs * rhs | ||
} | ||
|
||
var magnitude: T { | ||
var x = abs(real) | ||
var y = abs(imaginary) | ||
if x.isInfinite { return x } | ||
if y.isInfinite { return y } | ||
if x == 0 { return y } | ||
if x < y { swap(&x, &y) } | ||
let ratio = y / x | ||
return x * (1 + ratio * ratio).squareRoot() | ||
} | ||
} | ||
|
||
extension Complex: SignedNumeric { | ||
@differentiable(vjp: _vjpNegate where T: Differentiable) | ||
static prefix func - (operand: Complex) -> Complex { | ||
return Complex(real: -operand.real, imaginary: -operand.imaginary) | ||
} | ||
|
||
mutating func negate() { | ||
real.negate() | ||
imaginary.negate() | ||
} | ||
} | ||
|
||
extension Complex { | ||
@differentiable(vjp: _vjpDivide(lhs:rhs:) where T: Differentiable) | ||
static func / (lhs: Complex, rhs: Complex) -> Complex { | ||
var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary | ||
var x: T | ||
rxwei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
var y: T | ||
if c.magnitude >= d.magnitude { | ||
let ratio = d / c | ||
let denominator = c + d * ratio | ||
x = (a + b * ratio) / denominator | ||
y = (b - a * ratio) / denominator | ||
} else { | ||
let ratio = c / d | ||
let denominator = c * ratio + d | ||
x = (a * ratio + b) / denominator | ||
y = (b * ratio - a) / denominator | ||
} | ||
if x.isNaN && y.isNaN { | ||
if c == 0 && d == 0 && (!a.isNaN || !b.isNaN) { | ||
x = T(signOf: c, magnitudeOf: .infinity) * a | ||
y = T(signOf: c, magnitudeOf: .infinity) * b | ||
} else if (a.isInfinite || b.isInfinite) && c.isFinite && d.isFinite { | ||
a = T(signOf: a, magnitudeOf: a.isInfinite ? 1 : 0) | ||
b = T(signOf: b, magnitudeOf: b.isInfinite ? 1 : 0) | ||
x = .infinity * (a * c + b * d) | ||
y = .infinity * (b * c - a * d) | ||
} else if (c.isInfinite || d.isInfinite) && a.isFinite && b.isFinite { | ||
c = T(signOf: c, magnitudeOf: c.isInfinite ? 1 : 0) | ||
d = T(signOf: d, magnitudeOf: d.isInfinite ? 1 : 0) | ||
x = 0 * (a * c + b * d) | ||
y = 0 * (b * c - a * d) | ||
} | ||
} | ||
return Complex(real: x, imaginary: y) | ||
} | ||
|
||
static func /= (lhs: inout Complex, rhs: Complex) { | ||
lhs = lhs / rhs | ||
} | ||
} | ||
|
||
extension Complex { | ||
@differentiable(vjp: _vjpComplexConjugate where T: Differentiable) | ||
func complexConjugate() -> Complex { | ||
bartchr808 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Complex(real: real, imaginary: -imaginary) | ||
} | ||
} | ||
|
||
func abs<T>(_ z: Complex<T>) -> Complex<T> { | ||
return Complex(real: z.magnitude) | ||
} | ||
|
||
extension Complex { | ||
@differentiable(vjp: _vjpAdding(real:) where T: Differentiable, T.TangentVector == T) | ||
func adding(real: T) -> Complex { | ||
var c = self | ||
c.real += real | ||
return c | ||
} | ||
|
||
@differentiable(vjp: _vjpSubtracting(real:) where T: Differentiable, T.TangentVector == T) | ||
func subtracting(real: T) -> Complex { | ||
var c = self | ||
c.real -= real | ||
return c | ||
} | ||
|
||
@differentiable(vjp: _vjpAdding(imaginary:) where T: Differentiable, T.TangentVector == T) | ||
func adding(imaginary: T) -> Complex { | ||
var c = self | ||
c.imaginary += imaginary | ||
return c | ||
} | ||
|
||
@differentiable(vjp: _vjpSubtracting(imaginary:) where T: Differentiable, T.TangentVector == T) | ||
func subtracting(imaginary: T) -> Complex { | ||
var c = self | ||
c.imaginary -= imaginary | ||
return c | ||
} | ||
} | ||
|
||
extension Complex where T: Differentiable, T.TangentVector == T { | ||
static func _vjpInit(real: T, imaginary: T) -> (Complex, (Complex) -> (T, T)) { | ||
return (Complex(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) }) | ||
} | ||
} | ||
|
||
extension Complex where T: Differentiable { | ||
static func _vjpAdd(lhs: Complex, rhs: Complex) | ||
-> (Complex, (Complex) -> (Complex, Complex)) { | ||
return (lhs + rhs, { v in (v, v) }) | ||
} | ||
|
||
static func _vjpSubtract(lhs: Complex, rhs: Complex) | ||
-> (Complex, (Complex) -> (Complex, Complex)) { | ||
return (lhs - rhs, { v in (v, -v) }) | ||
} | ||
|
||
static func _vjpMultiply(lhs: Complex, rhs: Complex) | ||
-> (Complex, (Complex) -> (Complex, Complex)) { | ||
return (lhs * rhs, { v in (rhs * v, lhs * v) }) | ||
} | ||
|
||
static func _vjpDivide(lhs: Complex, rhs: Complex) | ||
-> (Complex, (Complex) -> (Complex, Complex)) { | ||
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) | ||
} | ||
|
||
static func _vjpNegate(operand: Complex) | ||
-> (Complex, (Complex) -> Complex) { | ||
return (-operand, { -$0 }) | ||
} | ||
|
||
func _vjpComplexConjugate() -> (Complex, (Complex) -> Complex) { | ||
return (complexConjugate(), { v in v.complexConjugate() }) | ||
} | ||
} | ||
|
||
extension Complex where T: Differentiable, T.TangentVector == T { | ||
func _vjpAdding(real: T) -> (Complex, (Complex) -> (Complex, T)) { | ||
return (self.adding(real: real), { ($0, $0.real) }) | ||
} | ||
|
||
func _vjpSubtracting(real: T) -> (Complex, (Complex) -> (Complex, T)) { | ||
return (self.subtracting(real: real), { ($0, -$0.real) }) | ||
} | ||
|
||
func _vjpAdding(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) { | ||
return (self.adding(real: real), { ($0, $0.imaginary) }) | ||
} | ||
|
||
func _vjpSubtracting(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) { | ||
return (self.subtracting(real: real), { ($0, -$0.imaginary) }) | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.