1
1
import TensorFlow
2
-
3
- //
4
- // Complex.swift
5
- // NumericAnnex
6
- //
7
- // Created by Xiaodi Wu on 3/25/17.
8
- //
9
- // Note
10
- // ====
11
- //
12
- // For maximum consistency with corresponding functions in C/C++, checks for
13
- // special values in `naturalExponential()`, `squareRoot()`, trigonometric
14
- // functions, and hyperbolic functions are adapted from libc++.
15
- //
16
- // Code in libc++ is dual-licensed under the MIT and UIUC/NCSA licenses.
17
- // Copyright © 2009-2017 contributors to the LLVM/libc++ project.
18
-
19
- /// A type to represent a complex value in Cartesian form.
20
- ///
21
- /// Additional Considerations
22
- /// -------------------------
23
- ///
24
- /// Floating-point types have special values that represent infinity or NaN
25
- /// ("not a number"). Complex functions in different languages may return
26
- /// different results when working with special values.
27
- ///
28
- /// Implementations in `Complex<T>` adhere to the [C standard][std] (Annex G) as
29
- /// closely as possible with respect to special values and branch cuts.
30
- ///
31
- /// To users unfamiliar with complex functions, the principal value returned by
32
- /// some complex functions may be unexpected. For example,
33
- /// `Double.cbrt(-8) == -2`, which is the __real root__, while
34
- /// `Complex.cbrt(-8) == 2 * Complex.exp(.i * .pi / 3)`, which is the
35
- /// __principal root__.
36
- ///
37
- /// [dfn]: http://mathworld.wolfram.com/BranchCut.html
38
- /// [std]: http://www.open-std.org/JTC1/SC22/WG14/www/standards.html#9899
39
- @_fixed_layout
40
- public struct Complex < T : TensorFlowFloatingPoint > : Differentiable {
2
+ // T : FloatingPoint & Differentiable
3
+ public struct Complex < T : FloatingPoint > {
41
4
// ---------------------------------------------------------------------------
42
5
// MARK: Stored Properties
43
6
// ---------------------------------------------------------------------------
44
7
45
8
/// The real component of the complex value.
46
- // @differentiable(vjp: _vjpReal)
47
9
public var real : T
48
10
49
11
/// The imaginary component of the complex value.
50
- // @differentiable(vjp: _vjpImaginary)
51
12
public var imaginary : T
52
13
53
14
// ---------------------------------------------------------------------------
54
15
// MARK: Initializers
55
16
// ---------------------------------------------------------------------------
56
- @differentiable ( wrt: ( real, imaginary) , vjp: _vjpInit)
57
17
public init ( real: T = 0 , imaginary: T = 0 ) {
58
18
self . real = real
59
19
self . imaginary = imaginary
60
20
}
21
+ }
61
22
23
+ extension Complex : Differentiable where T : Differentiable , T. TangentVector == T {
24
+ // ---------------------------------------------------------------------------
25
+ // MARK: Differentiability
26
+ // ---------------------------------------------------------------------------
62
27
public typealias TangentVector = Complex
63
- public typealias CotangentVector = Complex
64
28
public typealias AllDifferentiableVariables = Complex
65
- public func tangentVector( from cotangent: CotangentVector ) -> TangentVector {
66
- return cotangent
67
- }
68
29
}
69
30
70
31
extension Complex {
@@ -73,7 +34,7 @@ extension Complex {
73
34
// ---------------------------------------------------------------------------
74
35
75
36
/// The imaginary unit _i_.
76
- @_transparent // @_inlineable
37
+ @inlinable
77
38
public static var i : Complex {
78
39
return Complex ( real: 0 , imaginary: 1 )
79
40
}
@@ -82,7 +43,7 @@ extension Complex {
82
43
///
83
44
/// A complex value is finite if its real and imaginary components are both
84
45
/// finite. A component is finite if it is not infinity or NaN.
85
- @_transparent // @_inlineable
46
+ @inlinable
86
47
public var isFinite : Bool {
87
48
return real. isFinite && imaginary. isFinite
88
49
}
@@ -94,7 +55,7 @@ extension Complex {
94
55
///
95
56
/// Note that `isFinite` and `isInfinite` do not form a dichotomy because NaN
96
57
/// is neither finite nor infinite.
97
- @_transparent // @_inlineable
58
+ @inlinable
98
59
public var isInfinite : Bool {
99
60
return real. isInfinite || imaginary. isInfinite
100
61
}
@@ -109,7 +70,7 @@ extension Complex {
109
70
/// test whether a value is or is not NaN.
110
71
///
111
72
/// This property is `true` for both quiet and signaling NaNs.
112
- @_transparent // @_inlineable
73
+ @inlinable
113
74
public var isNaN : Bool {
114
75
return ( real. isNaN && !imaginary. isInfinite) ||
115
76
( imaginary. isNaN && !real. isInfinite)
@@ -119,7 +80,7 @@ extension Complex {
119
80
///
120
81
/// A complex value is equal to zero if its real and imaginary components both
121
82
/// represent either `-0.0` or `+0.0`.
122
- @_transparent // @_inlineable
83
+ @inlinable
123
84
public var isZero : Bool {
124
85
return real. isZero && imaginary. isZero
125
86
}
@@ -130,7 +91,7 @@ extension Complex : ExpressibleByIntegerLiteral {
130
91
// MARK: ExpressibleByIntegerLiteral
131
92
// ---------------------------------------------------------------------------
132
93
133
- @_transparent // @_inlineable
94
+ @inlinable
134
95
public init ( integerLiteral value: Int ) {
135
96
self . real = T ( value)
136
97
self . imaginary = 0
@@ -142,7 +103,7 @@ extension Complex : CustomStringConvertible {
142
103
// MARK: CustomStringConvertible
143
104
// ---------------------------------------------------------------------------
144
105
145
- @_transparent // @_inlineable
106
+ @inlinable
146
107
public var description : String {
147
108
return real. isNaN && real. sign == . minus
148
109
// At present, -NaN is described as "nan", which is acceptable for real
@@ -162,7 +123,7 @@ extension Complex : Equatable {
162
123
// MARK: Equatable
163
124
// ---------------------------------------------------------------------------
164
125
165
- @_transparent // @_inlineable
126
+ @inlinable
166
127
public static func == ( lhs: Complex , rhs: Complex ) -> Bool {
167
128
return lhs. real == rhs. real && lhs. imaginary == rhs. imaginary
168
129
}
@@ -173,29 +134,29 @@ extension Complex : AdditiveArithmetic {
173
134
// MARK: AdditiveArithmetic
174
135
// ---------------------------------------------------------------------------
175
136
176
- @_transparent // @_inlineable
177
- @differentiable ( vjp: _vjpAdd ( lhs: rhs: ) where T : Differentiable)
137
+ @inlinable
138
+ @differentiable ( vjp: _vjpAdd ( lhs: rhs: ) where T : Differentiable, T . TangentVector == T )
178
139
public static func + ( lhs: Complex , rhs: Complex ) -> Complex {
179
140
var lhs = lhs
180
141
lhs += rhs
181
142
return lhs
182
143
}
183
144
184
- @_transparent // @_inlineable
145
+ @inlinable
185
146
public static func += ( lhs: inout Complex , rhs: Complex ) {
186
147
lhs. real += rhs. real
187
148
lhs. imaginary += rhs. imaginary
188
149
}
189
150
190
- @_transparent // @_inlineable
191
- @differentiable ( vjp: _vjpSubtract ( lhs: rhs: ) where T : Differentiable)
151
+ @inlinable
152
+ @differentiable ( vjp: _vjpSubtract ( lhs: rhs: ) where T : Differentiable, T . TangentVector == T )
192
153
public static func - ( lhs: Complex , rhs: Complex ) -> Complex {
193
154
var lhs = lhs
194
155
lhs -= rhs
195
156
return lhs
196
157
}
197
158
198
- @_transparent // @_inlineable
159
+ @inlinable
199
160
public static func -= ( lhs: inout Complex , rhs: Complex ) {
200
161
lhs. real -= rhs. real
201
162
lhs. imaginary -= rhs. imaginary
@@ -213,8 +174,8 @@ extension Complex : Numeric {
213
174
self . imaginary = 0
214
175
}
215
176
216
- @_transparent // @_inlineable
217
- @differentiable ( vjp: _vjpMultiply ( lhs: rhs: ) where T : Differentiable)
177
+ @inlinable
178
+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: ) where T : Differentiable, T . TangentVector == T )
218
179
public static func * ( lhs: Complex , rhs: Complex ) -> Complex {
219
180
var a = lhs. real, b = lhs. imaginary, c = rhs. real, d = rhs. imaginary
220
181
let ac = a * c, bd = b * d, ad = a * d, bc = b * c
@@ -259,12 +220,12 @@ extension Complex : Numeric {
259
220
return Complex ( real: x, imaginary: y)
260
221
}
261
222
262
- @_transparent // @_inlineable
223
+ @inlinable
263
224
public static func *= ( lhs: inout Complex , rhs: Complex ) {
264
225
lhs = lhs * rhs
265
226
}
266
227
267
- @_transparent // @_inlineable
228
+ @inlinable
268
229
public var magnitude : T {
269
230
var x = abs ( real)
270
231
var y = abs ( imaginary)
@@ -282,23 +243,26 @@ extension Complex : SignedNumeric {
282
243
// MARK: SignedNumeric
283
244
// ---------------------------------------------------------------------------
284
245
285
- @_transparent // @_inlineable
286
- @differentiable ( vjp: _vjpNegate where T : Differentiable)
246
+ @inlinable
247
+ @differentiable ( vjp: _vjpNegate where T : Differentiable, T . TangentVector == T )
287
248
public static prefix func - ( operand: Complex ) -> Complex {
288
249
return Complex ( real: - operand. real, imaginary: - operand. imaginary)
289
250
}
290
251
291
- @_transparent // @_inlineable
252
+ @inlinable
292
253
public mutating func negate( ) {
293
254
real. negate ( )
294
255
imaginary. negate ( )
295
256
}
296
257
}
297
258
298
259
extension Complex {
260
+ // ---------------------------------------------------------------------------
261
+ // MARK: Division
262
+ // ---------------------------------------------------------------------------
299
263
300
- @_transparent // @_inlineable
301
- @differentiable ( vjp: _vjpDivide ( lhs: rhs: ) where T : Differentiable)
264
+ @inlinable
265
+ @differentiable ( vjp: _vjpDivide ( lhs: rhs: ) where T : Differentiable, T . TangentVector == T )
302
266
public static func / ( lhs: Complex , rhs: Complex ) -> Complex {
303
267
var a = lhs. real, b = lhs. imaginary, c = rhs. real, d = rhs. imaginary
304
268
var x : T
@@ -336,66 +300,60 @@ extension Complex {
336
300
return Complex ( real: x, imaginary: y)
337
301
}
338
302
339
- @_transparent // @_inlineable
303
+ @inlinable
340
304
public static func /= ( lhs: inout Complex , rhs: Complex ) {
341
305
lhs = lhs / rhs
342
306
}
343
307
}
344
308
345
309
extension Complex {
346
- func complexConjugate( ) -> Complex {
310
+ @inlinable
311
+ public func complexConjugate( ) -> Complex {
347
312
return Complex ( real: real, imaginary: - imaginary)
348
313
}
349
314
}
350
315
351
316
/// Returns the absolute value (magnitude, modulus) of `z`.
352
- @_transparent
317
+ @inlinable
353
318
public func abs< T> ( _ z: Complex < T > ) -> Complex < T > {
354
319
return Complex ( real: z. magnitude)
355
320
}
356
321
357
322
extension Complex {
358
- @differentiable ( vjp: _vjpAdding ( real: ) )
323
+ @inlinable
324
+ @differentiable ( vjp: _vjpAdding ( real: ) where T : Differentiable, T . TangentVector == T)
359
325
public func adding( real: T ) -> Complex {
360
326
var c = self
361
327
c. real += real
362
328
return c
363
329
}
364
330
365
- @differentiable ( vjp: _vjpSubtracting ( real: ) )
331
+ @inlinable
332
+ @differentiable ( vjp: _vjpSubtracting ( real: ) where T : Differentiable, T . TangentVector == T)
366
333
public func subtracting( real: T ) -> Complex {
367
334
var c = self
368
335
c. real -= real
369
336
return c
370
337
}
371
338
372
- @differentiable ( vjp: _vjpAdding ( imaginary: ) )
339
+ @inlinable
340
+ @differentiable ( vjp: _vjpAdding ( imaginary: ) where T : Differentiable, T . TangentVector == T)
373
341
public func adding( imaginary: T ) -> Complex {
374
342
var c = self
375
343
c. imaginary += imaginary
376
344
return c
377
345
}
378
-
379
- @differentiable ( vjp: _vjpSubtracting ( imaginary: ) )
346
+
347
+ @inlinable
348
+ @differentiable ( vjp: _vjpSubtracting ( imaginary: ) where T : Differentiable, T . TangentVector == T)
380
349
public func subtracting( imaginary: T ) -> Complex {
381
350
var c = self
382
351
c. imaginary -= imaginary
383
352
return c
384
353
}
385
354
}
386
355
387
- extension Complex {
388
- @inlinable
389
- static func _vjpInit( real: T , imaginary: T ) -> ( Complex < T > , ( Complex < T > ) -> ( T , T ) ) {
390
- // let orig: Complex<T> = Complex(real: real, imaginary: imaginary)
391
- // let pb: (Complex) -> (T, T) = { v in
392
- // return (v.real, v.imaginary)
393
- // }
394
- return ( Complex ( real: real, imaginary: imaginary) , { v in
395
- return ( v. real, v. imaginary)
396
- } )
397
- }
398
-
356
+ extension Complex where T : Differentiable , T. TangentVector == T {
399
357
@inlinable
400
358
static func _vjpAdd( lhs: Complex , rhs: Complex )
401
359
-> ( Complex , ( Complex ) -> ( Complex , Complex ) ) {
0 commit comments