Skip to content

Commit 3153dd8

Browse files
Add a new func to helper and fix some tests
1 parent 0ec6f44 commit 3153dd8

File tree

9 files changed

+345
-264
lines changed

9 files changed

+345
-264
lines changed

tests/helper.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import dpctl
22
import dpnp
3+
import pytest
34

45

5-
def get_all_dtypes(no_bool=False,
6-
no_float16=True,
7-
no_complex=False,
8-
no_none=False,
9-
device=None):
6+
def get_all_dtypes(
7+
no_bool=False, no_float16=True, no_complex=False, no_none=False, device=None
8+
):
109
"""
1110
Build a list of types supported by DPNP based on input flags and device capabilities.
1211
"""
@@ -37,3 +36,29 @@ def get_all_dtypes(no_bool=False,
3736
if not no_none:
3837
dtypes.append(None)
3938
return dtypes
39+
40+
41+
def skip_or_change_if_dtype_not_supported(dtype, device=None, change_dtype=False):
42+
"""
43+
The function to check input type supported in DPNP based on the device capabilities.
44+
"""
45+
46+
dev = dpctl.select_default_device() if device is None else device
47+
dev_has_dp = dev.has_aspect_fp64
48+
if dtype in [dpnp.float32, dpnp.float64] and dev_has_dp is False:
49+
if change_dtype:
50+
return dpnp.float32
51+
else:
52+
pytest.skip(
53+
f"{dev.name} does not support double precision floating point types"
54+
)
55+
dev_has_hp = dev.has_aspect_fp16
56+
if dtype in [dpnp.complex64, dpnp.complex128] and dev_has_hp is False:
57+
if change_dtype:
58+
return dpnp.complex64
59+
else:
60+
pytest.skip(
61+
f"{dev.name} does not support double precision floating point types"
62+
)
63+
64+
return dtype

tests/test_absolute.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
import numpy
77

88

9-
@pytest.mark.parametrize(
10-
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)
11-
)
9+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
1210
def test_abs(dtype):
1311
a = numpy.array([1, 0, 2, -3, -1, 2, 21, -9], dtype=dtype)
1412
ia = inp.array(a)

tests/test_arithmetic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import unittest
22
import pytest
3+
from .helper import skip_or_change_if_dtype_not_supported
34

45
from tests.third_party.cupy import testing
56

67

78
class TestArithmetic(unittest.TestCase):
8-
99
@testing.for_float_dtypes()
1010
@testing.numpy_cupy_allclose()
1111
def test_modf_part1(self, xp, dtype):
12+
dtype = skip_or_change_if_dtype_not_supported(dtype, change_dtype=True)
1213
a = xp.array([-2.5, -1.5, -0.5, 0, 0.5, 1.5, 2.5], dtype=dtype)
1314
b, _ = xp.modf(a)
1415

@@ -17,6 +18,7 @@ def test_modf_part1(self, xp, dtype):
1718
@testing.for_float_dtypes()
1819
@testing.numpy_cupy_allclose()
1920
def test_modf_part2(self, xp, dtype):
21+
dtype = skip_or_change_if_dtype_not_supported(dtype, change_dtype=True)
2022
a = xp.array([-2.5, -1.5, -0.5, 0, 0.5, 1.5, 2.5], dtype=dtype)
2123
_, c = xp.modf(a)
2224

@@ -26,19 +28,22 @@ def test_modf_part2(self, xp, dtype):
2628
@testing.for_float_dtypes()
2729
@testing.numpy_cupy_allclose()
2830
def test_nanprod(self, xp, dtype):
31+
dtype = skip_or_change_if_dtype_not_supported(dtype, change_dtype=True)
2932
a = xp.array([-2.5, -1.5, xp.nan, 10.5, 1.5, xp.nan], dtype=dtype)
3033
return xp.nanprod(a)
3134

3235
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
3336
@testing.for_float_dtypes()
3437
@testing.numpy_cupy_allclose()
3538
def test_nansum(self, xp, dtype):
39+
dtype = skip_or_change_if_dtype_not_supported(dtype, change_dtype=True)
3640
a = xp.array([-2.5, -1.5, xp.nan, 10.5, 1.5, xp.nan], dtype=dtype)
3741
return xp.nansum(a)
3842

3943
@testing.for_float_dtypes()
4044
@testing.numpy_cupy_allclose()
4145
def test_remainder(self, xp, dtype):
46+
dtype = skip_or_change_if_dtype_not_supported(dtype, change_dtype=True)
4247
a = xp.array([5, -3, -2, -1, -5], dtype=dtype)
4348
b = xp.full(a.size, 3, dtype=dtype)
4449

0 commit comments

Comments
 (0)