Skip to content

Commit 8efd438

Browse files
Update dpnp.fft to run on CUDA (#2332)
This PR fixes the current issues with `dpnp.fft.fftn()` and `dpnp.fft.rfftn()` on CUDA and removes the test skips for them. The `incorrect result` issue for `dpnp.fft.fftn()` on CUDA occurred because preparing the input array when `batch_fft=True` could turn it into an `F contiguous` array, whereas cuFFT requires a `C contiguous` input array for correct execution. The `Invalid strides` error is caused by a bug in oneMath ([631](uxlfoundation/oneMath#631)). As a workaround until that issue is resolved, the last two axes are swapped when the last dimension is 1 and there are multiple axes; in this case the strides inside cuFFT are calculated correctly. Additionally, the arguments for `test_erf` in `skipped_tests_cuda.tbl` were updated so that it is skipped on CUDA.
1 parent 25bf9cc commit 8efd438

File tree

5 files changed

+25
-32
lines changed

5 files changed

+25
-32
lines changed

dpnp/fft/dpnp_utils_fft.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,13 +395,33 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
395395
a = dpnp.reshape(a, local_shape)
396396
index = 1
397397

398+
# cuFFT requires input arrays to be C-contiguous (row-major)
399+
# for correct execution
400+
if (
401+
dpnp.is_cuda_backend(a) and not a.flags.c_contiguous
402+
): # pragma: no cover
403+
a = dpnp.ascontiguousarray(a)
404+
405+
# Workaround for cuFFT to avoid "Invalid strides" error when
406+
# the last dimension is 1 and there are multiple axes
407+
# by swapping the last two axes to correct the input.
408+
# TODO: Remove this once the oneMath issue is resolved
409+
# https://github.com/uxlfoundation/oneMath/issues/631
410+
cufft_wa = dpnp.is_cuda_backend(a) and a.shape[-1] == 1 and len(axes) > 1
411+
if cufft_wa: # pragma: no cover
412+
a = dpnp.moveaxis(a, -1, -2)
413+
398414
a_strides = _standardize_strides_to_nonzero(a.strides, a.shape)
399415
dsc, out_strides = _commit_descriptor(
400416
a, forward, in_place, c2c, a_strides, index, batch_fft
401417
)
402418
res = _compute_result(dsc, a, out, forward, c2c, out_strides)
403419
res = _scale_result(res, a.shape, norm, forward, index)
404420

421+
# Revert swapped axes
422+
if cufft_wa: # pragma: no cover
423+
res = dpnp.moveaxis(res, -1, -2)
424+
405425
if batch_fft:
406426
tmp_shape = a_shape_orig[:-1] + (res.shape[-1],)
407427
res = dpnp.reshape(res, tmp_shape)

dpnp/tests/skipped_tests_cuda.tbl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -813,8 +813,7 @@ tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{extern
813813
# erf
814814
tests/test_special.py::test_erf
815815
tests/test_special.py::test_erf_fallback
816-
tests/test_strides.py::test_strides_erf[(10,)-int32]
817-
tests/test_strides.py::test_strides_erf[(10,)-int64]
818-
tests/test_strides.py::test_strides_erf[(10,)-float32]
819-
tests/test_strides.py::test_strides_erf[(10,)-float64]
820-
tests/test_strides.py::test_strides_erf[(10,)-None]
816+
tests/test_strides.py::test_erf[int32]
817+
tests/test_strides.py::test_erf[int64]
818+
tests/test_strides.py::test_erf[float32]
819+
tests/test_strides.py::test_erf[float64]

dpnp/tests/test_fft.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
get_all_dtypes,
1515
get_complex_dtypes,
1616
get_float_dtypes,
17-
is_cuda_device,
1817
)
1918

2019

@@ -443,11 +442,6 @@ def setup_method(self):
443442
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
444443
@pytest.mark.parametrize("order", ["C", "F"])
445444
def test_fftn(self, dtype, axes, norm, order):
446-
if is_cuda_device():
447-
if order == "C" and axes == (0, 1, 2):
448-
pass
449-
else:
450-
pytest.skip("SAT-7587")
451445
a_np = generate_random_numpy_array((2, 3, 4, 5), dtype, order)
452446
a = dpnp.array(a_np)
453447

@@ -482,9 +476,6 @@ def test_fftn_repeated_axes(self, axes):
482476
@pytest.mark.parametrize("axes", [(2, 3, 3, 2), (0, 0, 3, 3)])
483477
@pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 9)])
484478
def test_fftn_repeated_axes_with_s(self, axes, s):
485-
if is_cuda_device():
486-
if axes == (0, 0, 3, 3) and s == (7, 8, 10, 9):
487-
pytest.skip("SAT-7587")
488479
a_np = generate_random_numpy_array((2, 3, 4, 5), dtype=numpy.complex64)
489480
a = dpnp.array(a_np)
490481

@@ -504,11 +495,6 @@ def test_fftn_repeated_axes_with_s(self, axes, s):
504495
@pytest.mark.parametrize("axes", [(0, 1, 2, 3), (1, 2, 1, 2), (2, 2, 2, 3)])
505496
@pytest.mark.parametrize("s", [(2, 3, 4, 5), (5, 4, 7, 8), (2, 5, 1, 2)])
506497
def test_fftn_out(self, axes, s):
507-
if is_cuda_device():
508-
if axes == (0, 1, 2, 3):
509-
pytest.skip("SAT-7587")
510-
elif s == (2, 5, 1, 2) and axes in [(1, 2, 1, 2), (2, 2, 2, 3)]:
511-
pytest.skip("SAT-7587")
512498
a_np = generate_random_numpy_array((2, 3, 4, 5), dtype=numpy.complex64)
513499
a = dpnp.array(a_np)
514500

@@ -1082,9 +1068,6 @@ def test_rfftn_repeated_axes_with_s(self, axes, s):
10821068
@pytest.mark.parametrize("axes", [(0, 1, 2, 3), (1, 2, 1, 2), (2, 2, 2, 3)])
10831069
@pytest.mark.parametrize("s", [(2, 3, 4, 5), (5, 6, 7, 9), (2, 5, 1, 2)])
10841070
def test_rfftn_out(self, axes, s):
1085-
if is_cuda_device():
1086-
if axes == (0, 1, 2, 3) and s == (2, 5, 1, 2):
1087-
pytest.skip("SAT-7587")
10881071
x = numpy.random.uniform(-10, 10, 120)
10891072
a_np = numpy.array(x, dtype=numpy.float32).reshape(2, 3, 4, 5)
10901073
a = dpnp.asarray(a_np)

dpnp/tests/third_party/cupy/fft_tests/test_fft.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import dpnp as cupy
8-
from dpnp.tests.helper import has_support_aspect64, is_cuda_device
8+
from dpnp.tests.helper import has_support_aspect64
99
from dpnp.tests.third_party.cupy import testing
1010
from dpnp.tests.third_party.cupy.testing._loops import _wraps_partial
1111

@@ -413,8 +413,6 @@ class TestFft2:
413413
type_check=has_support_aspect64(),
414414
)
415415
def test_fft2(self, xp, dtype, order, enable_nd):
416-
if is_cuda_device() and self.shape == (2, 3, 4, 5):
417-
pytest.skip("SAT-7587")
418416
# assert config.enable_nd_planning == enable_nd
419417
a = testing.shaped_random(self.shape, xp, dtype)
420418
if order == "F":
@@ -442,8 +440,6 @@ def test_fft2(self, xp, dtype, order, enable_nd):
442440
type_check=has_support_aspect64(),
443441
)
444442
def test_ifft2(self, xp, dtype, order, enable_nd):
445-
if is_cuda_device() and self.shape == (2, 3, 4, 5):
446-
pytest.skip("SAT-7587")
447443
# assert config.enable_nd_planning == enable_nd
448444
a = testing.shaped_random(self.shape, xp, dtype)
449445
if order == "F":
@@ -507,8 +503,6 @@ class TestFftn:
507503
type_check=has_support_aspect64(),
508504
)
509505
def test_fftn(self, xp, dtype, order, enable_nd):
510-
if is_cuda_device() and self.shape == (2, 3, 4, 5):
511-
pytest.skip("SAT-7587")
512506
# assert config.enable_nd_planning == enable_nd
513507
a = testing.shaped_random(self.shape, xp, dtype)
514508
if order == "F":
@@ -536,8 +530,6 @@ def test_fftn(self, xp, dtype, order, enable_nd):
536530
type_check=has_support_aspect64(),
537531
)
538532
def test_ifftn(self, xp, dtype, order, enable_nd):
539-
if is_cuda_device() and self.shape == (2, 3, 4, 5):
540-
pytest.skip("SAT-7587")
541533
# assert config.enable_nd_planning == enable_nd
542534
a = testing.shaped_random(self.shape, xp, dtype)
543535
if order == "F":

dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dpnp.tests.helper import (
88
has_support_aspect64,
99
is_cpu_device,
10-
is_cuda_device,
1110
)
1211
from dpnp.tests.third_party.cupy import testing
1312
from dpnp.tests.third_party.cupy.testing import _condition

0 commit comments

Comments
 (0)