Skip to content

Commit 88911fb

Browse files
Fix FFT negative strides case (#2202)
* Fix issue with negative strides * Apply suggestions from code review Co-authored-by: Anton <[email protected]> --------- Co-authored-by: Anton <[email protected]>
1 parent 2bab446 commit 88911fb

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

dpnp/fft/dpnp_utils_fft.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
7878
shape = a_shape[index:]
7979
strides = (0,) + a_strides[index:]
8080
if c2c: # c2c FFT
81+
assert dpnp.issubdtype(a.dtype, dpnp.complexfloating)
8182
if a.dtype == dpnp.complex64:
8283
dsc = fi.Complex64Descriptor(shape)
8384
else:
8485
dsc = fi.Complex128Descriptor(shape)
8586
else: # r2c/c2r FFT
87+
assert dpnp.issubdtype(a.dtype, dpnp.inexact)
8688
if a.dtype in [dpnp.float32, dpnp.complex64]:
8789
dsc = fi.Real32Descriptor(shape)
8890
else:
@@ -262,12 +264,14 @@ def _copy_array(x, complex_input):
262264
in-place FFT can be performed.
263265
"""
264266
dtype = x.dtype
267+
copy_flag = False
265268
if numpy.min(x.strides) < 0:
266269
# negative stride is not allowed in OneMKL FFT
267270
# TODO: support for negative strides will be added in the future
268271
# versions of OneMKL, see discussion in MKLD-17597
269272
copy_flag = True
270-
elif complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating):
273+
274+
if complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating):
271275
# c2c/c2r FFT, if input is not complex, convert to complex
272276
copy_flag = True
273277
if dtype in [dpnp.float16, dpnp.float32]:
@@ -279,8 +283,6 @@ def _copy_array(x, complex_input):
279283
# float32 or float64 depending on device capabilities
280284
copy_flag = True
281285
dtype = map_dtype_to_device(dpnp.float64, x.sycl_device)
282-
else:
283-
copy_flag = False
284286

285287
if copy_flag:
286288
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")

dpnp/tests/test_fft.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,16 @@ def test_fft_validate_out(self):
378378
out = dpnp.empty((10,), dtype=dpnp.float32)
379379
assert_raises(TypeError, dpnp.fft.fft, a, out=out)
380380

381+
@pytest.mark.parametrize(
382+
"dtype", get_all_dtypes(no_none=True, no_bool=True)
383+
)
384+
def test_negative_stride(self, dtype):
385+
a = dpnp.arange(10, dtype=dtype)
386+
result = dpnp.fft.fft(a[::-1])
387+
expected = numpy.fft.fft(a.asnumpy()[::-1])
388+
389+
assert_dtype_allclose(result, expected, check_only_type_kind=True)
390+
381391

382392
class TestFft2:
383393
def setup_method(self):

0 commit comments

Comments
 (0)