Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d5e78cf

Browse files
committed
Loosen constraints on complex number.
1 parent 346b9b7 commit d5e78cf

File tree

1 file changed

+47
-89
lines changed

1 file changed

+47
-89
lines changed

Sources/TensorFlow/Core/Complex.swift

Lines changed: 47 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,31 @@
11
import TensorFlow
2-
3-
//
4-
// Complex.swift
5-
// NumericAnnex
6-
//
7-
// Created by Xiaodi Wu on 3/25/17.
8-
//
9-
// Note
10-
// ====
11-
//
12-
// For maximum consistency with corresponding functions in C/C++, checks for
13-
// special values in `naturalExponential()`, `squareRoot()`, trigonometric
14-
// functions, and hyperbolic functions are adapted from libc++.
15-
//
16-
// Code in libc++ is dual-licensed under the MIT and UIUC/NCSA licenses.
17-
// Copyright © 2009-2017 contributors to the LLVM/libc++ project.
18-
19-
/// A type to represent a complex value in Cartesian form.
20-
///
21-
/// Additional Considerations
22-
/// -------------------------
23-
///
24-
/// Floating-point types have special values that represent infinity or NaN
25-
/// ("not a number"). Complex functions in different languages may return
26-
/// different results when working with special values.
27-
///
28-
/// Implementations in `Complex<T>` adhere to the [C standard][std] (Annex G) as
29-
/// closely as possible with respect to special values and branch cuts.
30-
///
31-
/// To users unfamiliar with complex functions, the principal value returned by
32-
/// some complex functions may be unexpected. For example,
33-
/// `Double.cbrt(-8) == -2`, which is the __real root__, while
34-
/// `Complex.cbrt(-8) == 2 * Complex.exp(.i * .pi / 3)`, which is the
35-
/// __principal root__.
36-
///
37-
/// [dfn]: http://mathworld.wolfram.com/BranchCut.html
38-
/// [std]: http://www.open-std.org/JTC1/SC22/WG14/www/standards.html#9899
39-
@_fixed_layout
40-
public struct Complex<T : TensorFlowFloatingPoint> : Differentiable {
2+
// T : FloatingPoint & Differentiable
3+
public struct Complex<T : FloatingPoint> {
414
// ---------------------------------------------------------------------------
425
// MARK: Stored Properties
436
// ---------------------------------------------------------------------------
447

458
/// The real component of the complex value.
46-
// @differentiable(vjp: _vjpReal)
479
public var real: T
4810

4911
/// The imaginary component of the complex value.
50-
// @differentiable(vjp: _vjpImaginary)
5112
public var imaginary: T
5213

5314
// ---------------------------------------------------------------------------
5415
// MARK: Initializers
5516
// ---------------------------------------------------------------------------
56-
@differentiable(wrt: (real, imaginary), vjp: _vjpInit)
5717
public init(real: T = 0, imaginary: T = 0) {
5818
self.real = real
5919
self.imaginary = imaginary
6020
}
21+
}
6122

23+
extension Complex : Differentiable where T : Differentiable, T.TangentVector == T {
24+
// ---------------------------------------------------------------------------
25+
// MARK: Differentiability
26+
// ---------------------------------------------------------------------------
6227
public typealias TangentVector = Complex
63-
public typealias CotangentVector = Complex
6428
public typealias AllDifferentiableVariables = Complex
65-
public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
66-
return cotangent
67-
}
6829
}
6930

7031
extension Complex {
@@ -73,7 +34,7 @@ extension Complex {
7334
// ---------------------------------------------------------------------------
7435

7536
/// The imaginary unit _i_.
76-
@_transparent // @_inlineable
37+
@inlinable
7738
public static var i: Complex {
7839
return Complex(real: 0, imaginary: 1)
7940
}
@@ -82,7 +43,7 @@ extension Complex {
8243
///
8344
/// A complex value is finite if its real and imaginary components are both
8445
/// finite. A component is finite if it is not infinity or NaN.
85-
@_transparent // @_inlineable
46+
@inlinable
8647
public var isFinite: Bool {
8748
return real.isFinite && imaginary.isFinite
8849
}
@@ -94,7 +55,7 @@ extension Complex {
9455
///
9556
/// Note that `isFinite` and `isInfinite` do not form a dichotomy because NaN
9657
/// is neither finite nor infinite.
97-
@_transparent // @_inlineable
58+
@inlinable
9859
public var isInfinite: Bool {
9960
return real.isInfinite || imaginary.isInfinite
10061
}
@@ -109,7 +70,7 @@ extension Complex {
10970
/// test whether a value is or is not NaN.
11071
///
11172
/// This property is `true` for both quiet and signaling NaNs.
112-
@_transparent // @_inlineable
73+
@inlinable
11374
public var isNaN: Bool {
11475
return (real.isNaN && !imaginary.isInfinite) ||
11576
(imaginary.isNaN && !real.isInfinite)
@@ -119,7 +80,7 @@ extension Complex {
11980
///
12081
/// A complex value is equal to zero if its real and imaginary components both
12182
/// represent either `-0.0` or `+0.0`.
122-
@_transparent // @_inlineable
83+
@inlinable
12384
public var isZero: Bool {
12485
return real.isZero && imaginary.isZero
12586
}
@@ -130,7 +91,7 @@ extension Complex : ExpressibleByIntegerLiteral {
13091
// MARK: ExpressibleByIntegerLiteral
13192
// ---------------------------------------------------------------------------
13293

133-
@_transparent // @_inlineable
94+
@inlinable
13495
public init(integerLiteral value: Int) {
13596
self.real = T(value)
13697
self.imaginary = 0
@@ -142,7 +103,7 @@ extension Complex : CustomStringConvertible {
142103
// MARK: CustomStringConvertible
143104
// ---------------------------------------------------------------------------
144105

145-
@_transparent // @_inlineable
106+
@inlinable
146107
public var description: String {
147108
return real.isNaN && real.sign == .minus
148109
// At present, -NaN is described as "nan", which is acceptable for real
@@ -162,7 +123,7 @@ extension Complex : Equatable {
162123
// MARK: Equatable
163124
// ---------------------------------------------------------------------------
164125

165-
@_transparent // @_inlineable
126+
@inlinable
166127
public static func == (lhs: Complex, rhs: Complex) -> Bool {
167128
return lhs.real == rhs.real && lhs.imaginary == rhs.imaginary
168129
}
@@ -173,29 +134,29 @@ extension Complex : AdditiveArithmetic {
173134
// MARK: AdditiveArithmetic
174135
// ---------------------------------------------------------------------------
175136

176-
@_transparent // @_inlineable
177-
@differentiable(vjp: _vjpAdd(lhs:rhs:) where T : Differentiable)
137+
@inlinable
138+
@differentiable(vjp: _vjpAdd(lhs:rhs:) where T : Differentiable, T.TangentVector == T)
178139
public static func + (lhs: Complex, rhs: Complex) -> Complex {
179140
var lhs = lhs
180141
lhs += rhs
181142
return lhs
182143
}
183144

184-
@_transparent // @_inlineable
145+
@inlinable
185146
public static func += (lhs: inout Complex, rhs: Complex) {
186147
lhs.real += rhs.real
187148
lhs.imaginary += rhs.imaginary
188149
}
189150

190-
@_transparent // @_inlineable
191-
@differentiable(vjp: _vjpSubtract(lhs:rhs:) where T : Differentiable)
151+
@inlinable
152+
@differentiable(vjp: _vjpSubtract(lhs:rhs:) where T : Differentiable, T.TangentVector == T)
192153
public static func - (lhs: Complex, rhs: Complex) -> Complex {
193154
var lhs = lhs
194155
lhs -= rhs
195156
return lhs
196157
}
197158

198-
@_transparent // @_inlineable
159+
@inlinable
199160
public static func -= (lhs: inout Complex, rhs: Complex) {
200161
lhs.real -= rhs.real
201162
lhs.imaginary -= rhs.imaginary
@@ -213,8 +174,8 @@ extension Complex : Numeric {
213174
self.imaginary = 0
214175
}
215176

216-
@_transparent // @_inlineable
217-
@differentiable(vjp: _vjpMultiply(lhs:rhs:) where T : Differentiable)
177+
@inlinable
178+
@differentiable(vjp: _vjpMultiply(lhs:rhs:) where T : Differentiable, T.TangentVector == T)
218179
public static func * (lhs: Complex, rhs: Complex) -> Complex {
219180
var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary
220181
let ac = a * c, bd = b * d, ad = a * d, bc = b * c
@@ -259,12 +220,12 @@ extension Complex : Numeric {
259220
return Complex(real: x, imaginary: y)
260221
}
261222

262-
@_transparent // @_inlineable
223+
@inlinable
263224
public static func *= (lhs: inout Complex, rhs: Complex) {
264225
lhs = lhs * rhs
265226
}
266227

267-
@_transparent // @_inlineable
228+
@inlinable
268229
public var magnitude: T {
269230
var x = abs(real)
270231
var y = abs(imaginary)
@@ -282,23 +243,26 @@ extension Complex : SignedNumeric {
282243
// MARK: SignedNumeric
283244
// ---------------------------------------------------------------------------
284245

285-
@_transparent // @_inlineable
286-
@differentiable(vjp: _vjpNegate where T : Differentiable)
246+
@inlinable
247+
@differentiable(vjp: _vjpNegate where T : Differentiable, T.TangentVector == T)
287248
public static prefix func - (operand: Complex) -> Complex {
288249
return Complex(real: -operand.real, imaginary: -operand.imaginary)
289250
}
290251

291-
@_transparent // @_inlineable
252+
@inlinable
292253
public mutating func negate() {
293254
real.negate()
294255
imaginary.negate()
295256
}
296257
}
297258

298259
extension Complex {
260+
// ---------------------------------------------------------------------------
261+
// MARK: Division
262+
// ---------------------------------------------------------------------------
299263

300-
@_transparent // @_inlineable
301-
@differentiable(vjp: _vjpDivide(lhs:rhs:) where T : Differentiable)
264+
@inlinable
265+
@differentiable(vjp: _vjpDivide(lhs:rhs:) where T : Differentiable, T.TangentVector == T)
302266
public static func / (lhs: Complex, rhs: Complex) -> Complex {
303267
var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary
304268
var x: T
@@ -336,66 +300,60 @@ extension Complex {
336300
return Complex(real: x, imaginary: y)
337301
}
338302

339-
@_transparent // @_inlineable
303+
@inlinable
340304
public static func /= (lhs: inout Complex, rhs: Complex) {
341305
lhs = lhs / rhs
342306
}
343307
}
344308

345309
extension Complex {
346-
func complexConjugate() -> Complex {
310+
@inlinable
311+
public func complexConjugate() -> Complex {
347312
return Complex(real: real, imaginary: -imaginary)
348313
}
349314
}
350315

351316
/// Returns the absolute value (magnitude, modulus) of `z`.
352-
@_transparent
317+
@inlinable
353318
public func abs<T>(_ z: Complex<T>) -> Complex<T> {
354319
return Complex(real: z.magnitude)
355320
}
356321

357322
extension Complex {
358-
@differentiable(vjp: _vjpAdding(real:))
323+
@inlinable
324+
@differentiable(vjp: _vjpAdding(real:) where T : Differentiable, T.TangentVector == T)
359325
public func adding(real: T) -> Complex {
360326
var c = self
361327
c.real += real
362328
return c
363329
}
364330

365-
@differentiable(vjp: _vjpSubtracting(real:))
331+
@inlinable
332+
@differentiable(vjp: _vjpSubtracting(real:) where T : Differentiable, T.TangentVector == T)
366333
public func subtracting(real: T) -> Complex {
367334
var c = self
368335
c.real -= real
369336
return c
370337
}
371338

372-
@differentiable(vjp: _vjpAdding(imaginary:))
339+
@inlinable
340+
@differentiable(vjp: _vjpAdding(imaginary:) where T : Differentiable, T.TangentVector == T)
373341
public func adding(imaginary: T) -> Complex {
374342
var c = self
375343
c.imaginary += imaginary
376344
return c
377345
}
378-
379-
@differentiable(vjp: _vjpSubtracting(imaginary:))
346+
347+
@inlinable
348+
@differentiable(vjp: _vjpSubtracting(imaginary:) where T : Differentiable, T.TangentVector == T)
380349
public func subtracting(imaginary: T) -> Complex {
381350
var c = self
382351
c.imaginary -= imaginary
383352
return c
384353
}
385354
}
386355

387-
extension Complex {
388-
@inlinable
389-
static func _vjpInit(real: T, imaginary: T) -> (Complex<T>, (Complex<T>) -> (T, T)) {
390-
// let orig: Complex<T> = Complex(real: real, imaginary: imaginary)
391-
// let pb: (Complex) -> (T, T) = { v in
392-
// return (v.real, v.imaginary)
393-
// }
394-
return (Complex(real: real, imaginary: imaginary), { v in
395-
return (v.real, v.imaginary)
396-
})
397-
}
398-
356+
extension Complex where T : Differentiable, T.TangentVector == T {
399357
@inlinable
400358
static func _vjpAdd(lhs: Complex, rhs: Complex)
401359
-> (Complex, (Complex) -> (Complex, Complex)) {

0 commit comments

Comments
 (0)