@@ -78,11 +78,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
78
78
shape = a_shape [index :]
79
79
strides = (0 ,) + a_strides [index :]
80
80
if c2c : # c2c FFT
81
+ assert dpnp .issubdtype (a .dtype , dpnp .complexfloating )
81
82
if a .dtype == dpnp .complex64 :
82
83
dsc = fi .Complex64Descriptor (shape )
83
84
else :
84
85
dsc = fi .Complex128Descriptor (shape )
85
86
else : # r2c/c2r FFT
87
+ assert dpnp .issubdtype (a .dtype , dpnp .inexact )
86
88
if a .dtype in [dpnp .float32 , dpnp .complex64 ]:
87
89
dsc = fi .Real32Descriptor (shape )
88
90
else :
@@ -262,12 +264,14 @@ def _copy_array(x, complex_input):
262
264
in-place FFT can be performed.
263
265
"""
264
266
dtype = x .dtype
267
+ copy_flag = False
265
268
if numpy .min (x .strides ) < 0 :
266
269
# negative stride is not allowed in OneMKL FFT
267
270
# TODO: support for negative strides will be added in the future
268
271
# versions of OneMKL, see discussion in MKLD-17597
269
272
copy_flag = True
270
- elif complex_input and not dpnp .issubdtype (dtype , dpnp .complexfloating ):
273
+
274
+ if complex_input and not dpnp .issubdtype (dtype , dpnp .complexfloating ):
271
275
# c2c/c2r FFT, if input is not complex, convert to complex
272
276
copy_flag = True
273
277
if dtype in [dpnp .float16 , dpnp .float32 ]:
@@ -279,8 +283,6 @@ def _copy_array(x, complex_input):
279
283
# float32 or float64 depending on device capabilities
280
284
copy_flag = True
281
285
dtype = map_dtype_to_device (dpnp .float64 , x .sycl_device )
282
- else :
283
- copy_flag = False
284
286
285
287
if copy_flag :
286
288
x_copy = dpnp .empty_like (x , dtype = dtype , order = "C" )
0 commit comments