Skip to content

Commit 4ccdec8

Browse files
Workaround for cuFFT strides bug
1 parent f8d2570 commit 4ccdec8

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

dpnp/fft/dpnp_utils_fft.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,13 +400,26 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
400400
if dpnp.is_cuda_backend(a) and not a.flags.c_contiguous:
401401
a = dpnp.ascontiguousarray(a)
402402

403+
# w/a for cuFFT to avoid "Invalid strides" error when
404+
# the last dimension is 1 and there are multiple axes
405+
# by swapping the last two axes to correct the input.
406+
# TODO: Remove this ones the OneMath issue is resolved
407+
# https://github.com/uxlfoundation/oneMath/issues/631
408+
cufft_wa = dpnp.is_cuda_backend(a) and a.shape[-1] == 1 and len(axes) > 1
409+
if cufft_wa:
410+
a = dpnp.moveaxis(a, -1, -2)
411+
403412
a_strides = _standardize_strides_to_nonzero(a.strides, a.shape)
404413
dsc, out_strides = _commit_descriptor(
405414
a, forward, in_place, c2c, a_strides, index, batch_fft
406415
)
407416
res = _compute_result(dsc, a, out, forward, c2c, out_strides)
408417
res = _scale_result(res, a.shape, norm, forward, index)
409418

419+
# Revert swapped axes
420+
if cufft_wa:
421+
res = dpnp.moveaxis(res, -1, -2)
422+
410423
if batch_fft:
411424
tmp_shape = a_shape_orig[:-1] + (res.shape[-1],)
412425
res = dpnp.reshape(res, tmp_shape)

0 commit comments

Comments
 (0)