14
14
float32 ,
15
15
complex64 ,
16
16
)
17
- from ._array_object import Array , CPU_DEVICE
17
+ from ._array_object import Array , ALL_DEVICES
18
18
from ._data_type_functions import astype
19
19
from ._flags import requires_extension
20
20
@@ -36,7 +36,7 @@ def fft(
36
36
"""
37
37
if x .dtype not in _complex_floating_dtypes :
38
38
raise TypeError ("Only complex floating-point dtypes are allowed in fft" )
39
- res = Array ._new (np .fft .fft (x ._array , n = n , axis = axis , norm = norm ))
39
+ res = Array ._new (np .fft .fft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
40
40
# Note: np.fft functions improperly upcast float32 and complex64 to
41
41
# complex128
42
42
if x .dtype == complex64 :
@@ -59,7 +59,7 @@ def ifft(
59
59
"""
60
60
if x .dtype not in _complex_floating_dtypes :
61
61
raise TypeError ("Only complex floating-point dtypes are allowed in ifft" )
62
- res = Array ._new (np .fft .ifft (x ._array , n = n , axis = axis , norm = norm ))
62
+ res = Array ._new (np .fft .ifft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
63
63
# Note: np.fft functions improperly upcast float32 and complex64 to
64
64
# complex128
65
65
if x .dtype == complex64 :
@@ -82,7 +82,7 @@ def fftn(
82
82
"""
83
83
if x .dtype not in _complex_floating_dtypes :
84
84
raise TypeError ("Only complex floating-point dtypes are allowed in fftn" )
85
- res = Array ._new (np .fft .fftn (x ._array , s = s , axes = axes , norm = norm ))
85
+ res = Array ._new (np .fft .fftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
86
86
# Note: np.fft functions improperly upcast float32 and complex64 to
87
87
# complex128
88
88
if x .dtype == complex64 :
@@ -105,7 +105,7 @@ def ifftn(
105
105
"""
106
106
if x .dtype not in _complex_floating_dtypes :
107
107
raise TypeError ("Only complex floating-point dtypes are allowed in ifftn" )
108
- res = Array ._new (np .fft .ifftn (x ._array , s = s , axes = axes , norm = norm ))
108
+ res = Array ._new (np .fft .ifftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
109
109
# Note: np.fft functions improperly upcast float32 and complex64 to
110
110
# complex128
111
111
if x .dtype == complex64 :
@@ -128,7 +128,7 @@ def rfft(
128
128
"""
129
129
if x .dtype not in _real_floating_dtypes :
130
130
raise TypeError ("Only real floating-point dtypes are allowed in rfft" )
131
- res = Array ._new (np .fft .rfft (x ._array , n = n , axis = axis , norm = norm ))
131
+ res = Array ._new (np .fft .rfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
132
132
# Note: np.fft functions improperly upcast float32 and complex64 to
133
133
# complex128
134
134
if x .dtype == float32 :
@@ -151,7 +151,7 @@ def irfft(
151
151
"""
152
152
if x .dtype not in _complex_floating_dtypes :
153
153
raise TypeError ("Only complex floating-point dtypes are allowed in irfft" )
154
- res = Array ._new (np .fft .irfft (x ._array , n = n , axis = axis , norm = norm ))
154
+ res = Array ._new (np .fft .irfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
155
155
# Note: np.fft functions improperly upcast float32 and complex64 to
156
156
# complex128
157
157
if x .dtype == complex64 :
@@ -174,7 +174,7 @@ def rfftn(
174
174
"""
175
175
if x .dtype not in _real_floating_dtypes :
176
176
raise TypeError ("Only real floating-point dtypes are allowed in rfftn" )
177
- res = Array ._new (np .fft .rfftn (x ._array , s = s , axes = axes , norm = norm ))
177
+ res = Array ._new (np .fft .rfftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
178
178
# Note: np.fft functions improperly upcast float32 and complex64 to
179
179
# complex128
180
180
if x .dtype == float32 :
@@ -197,7 +197,7 @@ def irfftn(
197
197
"""
198
198
if x .dtype not in _complex_floating_dtypes :
199
199
raise TypeError ("Only complex floating-point dtypes are allowed in irfftn" )
200
- res = Array ._new (np .fft .irfftn (x ._array , s = s , axes = axes , norm = norm ))
200
+ res = Array ._new (np .fft .irfftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
201
201
# Note: np.fft functions improperly upcast float32 and complex64 to
202
202
# complex128
203
203
if x .dtype == complex64 :
@@ -220,7 +220,7 @@ def hfft(
220
220
"""
221
221
if x .dtype not in _complex_floating_dtypes :
222
222
raise TypeError ("Only complex floating-point dtypes are allowed in hfft" )
223
- res = Array ._new (np .fft .hfft (x ._array , n = n , axis = axis , norm = norm ))
223
+ res = Array ._new (np .fft .hfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
224
224
# Note: np.fft functions improperly upcast float32 and complex64 to
225
225
# complex128
226
226
if x .dtype == complex64 :
@@ -243,7 +243,7 @@ def ihfft(
243
243
"""
244
244
if x .dtype not in _real_floating_dtypes :
245
245
raise TypeError ("Only real floating-point dtypes are allowed in ihfft" )
246
- res = Array ._new (np .fft .ihfft (x ._array , n = n , axis = axis , norm = norm ))
246
+ res = Array ._new (np .fft .ihfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
247
247
# Note: np.fft functions improperly upcast float32 and complex64 to
248
248
# complex128
249
249
if x .dtype == float32 :
@@ -257,9 +257,9 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar
257
257
258
258
See its docstring for more information.
259
259
"""
260
- if device not in [ CPU_DEVICE , None ] :
260
+ if device not in ALL_DEVICES :
261
261
raise ValueError (f"Unsupported device { device !r} " )
262
- return Array ._new (np .fft .fftfreq (n , d = d ))
262
+ return Array ._new (np .fft .fftfreq (n , d = d ), device = device )
263
263
264
264
@requires_extension ('fft' )
265
265
def rfftfreq (n : int , / , * , d : float = 1.0 , device : Optional [Device ] = None ) -> Array :
@@ -268,9 +268,9 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A
268
268
269
269
See its docstring for more information.
270
270
"""
271
- if device not in [ CPU_DEVICE , None ] :
271
+ if device not in ALL_DEVICES :
272
272
raise ValueError (f"Unsupported device { device !r} " )
273
- return Array ._new (np .fft .rfftfreq (n , d = d ))
273
+ return Array ._new (np .fft .rfftfreq (n , d = d ), device = device )
274
274
275
275
@requires_extension ('fft' )
276
276
def fftshift (x : Array , / , * , axes : Union [int , Sequence [int ]] = None ) -> Array :
@@ -281,7 +281,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
281
281
"""
282
282
if x .dtype not in _floating_dtypes :
283
283
raise TypeError ("Only floating-point dtypes are allowed in fftshift" )
284
- return Array ._new (np .fft .fftshift (x ._array , axes = axes ))
284
+ return Array ._new (np .fft .fftshift (x ._array , axes = axes ), device = x . device )
285
285
286
286
@requires_extension ('fft' )
287
287
def ifftshift (x : Array , / , * , axes : Union [int , Sequence [int ]] = None ) -> Array :
@@ -292,7 +292,7 @@ def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
292
292
"""
293
293
if x .dtype not in _floating_dtypes :
294
294
raise TypeError ("Only floating-point dtypes are allowed in ifftshift" )
295
- return Array ._new (np .fft .ifftshift (x ._array , axes = axes ))
295
+ return Array ._new (np .fft .ifftshift (x ._array , axes = axes ), device = x . device )
296
296
297
297
__all__ = [
298
298
"fft" ,
0 commit comments