@@ -400,13 +400,26 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
400
400
if dpnp .is_cuda_backend (a ) and not a .flags .c_contiguous :
401
401
a = dpnp .ascontiguousarray (a )
402
402
403
+ # w/a for cuFFT to avoid "Invalid strides" error when
404
+ # the last dimension is 1 and there are multiple axes
405
+ # by swapping the last two axes to correct the input.
406
+ # TODO: Remove this ones the OneMath issue is resolved
407
+ # https://github.com/uxlfoundation/oneMath/issues/631
408
+ cufft_wa = dpnp .is_cuda_backend (a ) and a .shape [- 1 ] == 1 and len (axes ) > 1
409
+ if cufft_wa :
410
+ a = dpnp .moveaxis (a , - 1 , - 2 )
411
+
403
412
a_strides = _standardize_strides_to_nonzero (a .strides , a .shape )
404
413
dsc , out_strides = _commit_descriptor (
405
414
a , forward , in_place , c2c , a_strides , index , batch_fft
406
415
)
407
416
res = _compute_result (dsc , a , out , forward , c2c , out_strides )
408
417
res = _scale_result (res , a .shape , norm , forward , index )
409
418
419
+ # Revert swapped axes
420
+ if cufft_wa :
421
+ res = dpnp .moveaxis (res , - 1 , - 2 )
422
+
410
423
if batch_fft :
411
424
tmp_shape = a_shape_orig [:- 1 ] + (res .shape [- 1 ],)
412
425
res = dpnp .reshape (res , tmp_shape )
0 commit comments