Skip to content

Commit 46e6d73

Browse files
committed
Update fft tests
1 parent d8fa667 commit 46e6d73

File tree

1 file changed

+5
-54
lines changed

1 file changed

+5
-54
lines changed

dpnp/tests/third_party/cupy/fft_tests/test_fft.py

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ def test_ifftn_orders(self, dtype, enable_nd):
889889
pass
890890

891891

892-
# @testing.with_requires("numpy>=2.0")
892+
@testing.with_requires("numpy>=2.0")
893893
@pytest.mark.usefixtures("skip_forward_backward")
894894
@testing.parameterize(
895895
*testing.product(
@@ -925,17 +925,6 @@ def test_rfft(self, xp, dtype):
925925
def test_irfft(self, xp, dtype):
926926
a = testing.shaped_random(self.shape, xp, dtype)
927927
out = xp.fft.irfft(a, n=self.n, norm=self.norm)
928-
929-
if dtype == xp.float16 and xp is cupy:
930-
# XXX: np2.0: f16 dtypes differ
931-
out = out.astype(np.float16)
932-
elif (
933-
xp is np
934-
and np.lib.NumpyVersion(np.__version__) < "2.0.0"
935-
and dtype == np.float32
936-
):
937-
out = out.astype(np.float32)
938-
939928
return out
940929

941930

@@ -1008,7 +997,7 @@ def test_rfft_error_on_wrong_plan(self, dtype):
1008997
assert "Target array size does not match the plan." in str(ex.value)
1009998

1010999

1011-
# @testing.with_requires("numpy>=2.0")
1000+
@testing.with_requires("numpy>=2.0")
10121001
@pytest.mark.usefixtures("skip_forward_backward")
10131002
@testing.parameterize(
10141003
*(
@@ -1069,13 +1058,6 @@ def test_irfft2(self, xp, dtype, order, enable_nd):
10691058

10701059
if self.s is None and self.axes in [None, (-2, -1)]:
10711060
pytest.skip("Input is not Hermitian Symmetric")
1072-
elif dtype == xp.float16 and xp is cupy:
1073-
pytest.xfail("XXX: np2.0: f16 dtypes differ")
1074-
elif (
1075-
np.lib.NumpyVersion(np.__version__) < "2.0.0"
1076-
and dtype == np.float32
1077-
):
1078-
pytest.skip("dtypes differ")
10791061

10801062
a = testing.shaped_random(self.shape, xp, dtype)
10811063
if order == "F":
@@ -1105,7 +1087,7 @@ def test_irfft2(self, dtype):
11051087
xp.fft.irfft2(a, s=self.s, axes=self.axes, norm=self.norm)
11061088

11071089

1108-
# @testing.with_requires("numpy>=2.0")
1090+
@testing.with_requires("numpy>=2.0")
11091091
@pytest.mark.usefixtures("skip_forward_backward")
11101092
@testing.parameterize(
11111093
*(
@@ -1166,13 +1148,6 @@ def test_irfftn(self, xp, dtype, order, enable_nd):
11661148

11671149
if self.s is None and self.axes in [None, (-2, -1)]:
11681150
pytest.skip("Input is not Hermitian Symmetric")
1169-
elif dtype == xp.float16 and xp is cupy:
1170-
pytest.xfail("XXX: np2.0: f16 dtypes differ")
1171-
elif (
1172-
np.lib.NumpyVersion(np.__version__) < "2.0.0"
1173-
and dtype == np.float32
1174-
):
1175-
pytest.skip("dtypes differ")
11761151

11771152
a = testing.shaped_random(self.shape, xp, dtype)
11781153
if order == "F":
@@ -1243,10 +1218,6 @@ def test_rfftn(self, xp, dtype, enable_nd):
12431218
def test_irfftn(self, xp, dtype, enable_nd):
12441219
assert config.enable_nd_planning == enable_nd
12451220
a = testing.shaped_random(self.shape, xp, dtype)
1246-
1247-
if dtype == xp.float16 and xp is cupy:
1248-
pytest.xfail("XXX: np2.0: f16 dtypes differ")
1249-
12501221
if xp is np:
12511222
return xp.fft.irfftn(a, s=self.s, axes=self.axes, norm=self.norm)
12521223

@@ -1349,7 +1320,7 @@ def test_irfftn(self, dtype):
13491320
xp.fft.irfftn(a, s=self.s, axes=self.axes, norm=self.norm)
13501321

13511322

1352-
# @testing.with_requires("numpy>=2.0")
1323+
@testing.with_requires("numpy>=2.0")
13531324
@pytest.mark.usefixtures("skip_forward_backward")
13541325
@testing.parameterize(
13551326
*testing.product(
@@ -1373,17 +1344,6 @@ class TestHfft:
13731344
def test_hfft(self, xp, dtype):
13741345
a = testing.shaped_random(self.shape, xp, dtype)
13751346
out = xp.fft.hfft(a, n=self.n, norm=self.norm)
1376-
1377-
if dtype == xp.float16 and xp is cupy:
1378-
# XXX: np2.0: f16 dtypes differ
1379-
out = out.astype(np.float16)
1380-
elif (
1381-
xp is np
1382-
and np.lib.NumpyVersion(np.__version__) < "2.0.0"
1383-
and dtype == np.float32
1384-
):
1385-
out = out.astype(np.float32)
1386-
13871347
return out
13881348

13891349
@testing.for_all_dtypes(no_complex=True)
@@ -1396,16 +1356,7 @@ def test_hfft(self, xp, dtype):
13961356
)
13971357
def test_ihfft(self, xp, dtype):
13981358
a = testing.shaped_random(self.shape, xp, dtype)
1399-
out = xp.fft.ihfft(a, n=self.n, norm=self.norm)
1400-
1401-
if (
1402-
xp is np
1403-
and np.lib.NumpyVersion(np.__version__) < "2.0.0"
1404-
and dtype == np.float32
1405-
):
1406-
out = out.astype(np.complex64)
1407-
1408-
return out
1359+
return xp.fft.ihfft(a, n=self.n, norm=self.norm)
14091360

14101361

14111362
# @testing.with_requires("numpy>=2.0")

0 commit comments

Comments
 (0)