Skip to content

Commit d2da986

Browse files
authored
define a new helper function for tests called get_integer_float_dtypes (#2402)
A new helper function is defined for tests, to get integer and float dtypes.
1 parent f13e9a4 commit d2da986

File tree

9 files changed

+84
-136
lines changed

9 files changed

+84
-136
lines changed

dpnp/tests/helper.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,36 @@ def get_integer_dtypes(all_int_types=False, no_unsigned=False):
325325
return dtypes
326326

327327

328+
def get_integer_float_dtypes(
329+
all_int_types=False,
330+
no_unsigned=False,
331+
no_float16=True,
332+
device=None,
333+
xfail_dtypes=None,
334+
exclude=None,
335+
):
336+
"""
337+
Build a list of integer and float types supported by DPNP.
338+
"""
339+
dtypes = get_integer_dtypes(
340+
all_int_types=all_int_types, no_unsigned=no_unsigned
341+
)
342+
dtypes += get_float_dtypes(no_float16=no_float16, device=device)
343+
344+
def mark_xfail(dtype):
345+
if xfail_dtypes is not None and dtype in xfail_dtypes:
346+
return pytest.param(dtype, marks=pytest.mark.xfail)
347+
return dtype
348+
349+
def not_excluded(dtype):
350+
if exclude is None:
351+
return True
352+
return dtype not in exclude
353+
354+
dtypes = [mark_xfail(dtype) for dtype in dtypes if not_excluded(dtype)]
355+
return dtypes
356+
357+
328358
def has_support_aspect16(device=None):
329359
"""
330360
Return True if the device supports 16-bit precision floating point operations,

dpnp/tests/test_binary_ufuncs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_float_complex_dtypes,
2121
get_float_dtypes,
2222
get_integer_dtypes,
23+
get_integer_float_dtypes,
2324
has_support_aspect16,
2425
numpy_version,
2526
)
@@ -141,7 +142,7 @@ def test_invalid_out(self, xp, out):
141142
@pytest.mark.parametrize("func", ["fmax", "fmin", "maximum", "minimum"])
142143
class TestBoundFuncs:
143144
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
144-
def test_out(self, func, dtype):
145+
def test_basic(self, func, dtype):
145146
a = generate_random_numpy_array(10, dtype)
146147
b = generate_random_numpy_array(10, dtype)
147148
expected = getattr(numpy, func)(a, b)
@@ -278,7 +279,7 @@ def test_invalid_out(self, xp, out):
278279

279280
@pytest.mark.parametrize("func", ["floor_divide", "remainder"])
280281
class TestFloorDivideRemainder:
281-
ALL_DTYPES = get_all_dtypes(no_none=True, no_bool=True, no_complex=True)
282+
ALL_DTYPES = get_integer_float_dtypes()
282283

283284
def do_inplace_op(self, base, other, func):
284285
if func == "floor_divide":

dpnp/tests/test_histogram.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
get_float_complex_dtypes,
2020
get_float_dtypes,
2121
get_integer_dtypes,
22+
get_integer_float_dtypes,
2223
has_support_aspect64,
2324
numpy_version,
2425
)
2526

2627

2728
class TestDigitize:
28-
@pytest.mark.parametrize(
29-
"dtype", get_all_dtypes(no_bool=True, no_complex=True)
30-
)
29+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
3130
@pytest.mark.parametrize("right", [True, False])
3231
@pytest.mark.parametrize(
3332
"x, bins",
@@ -73,12 +72,8 @@ def test_digitize_inf(self, dtype, right):
7372
expected = numpy.digitize(x, bins, right=right)
7473
assert_dtype_allclose(result, expected)
7574

76-
@pytest.mark.parametrize(
77-
"dtype_x", get_all_dtypes(no_bool=True, no_complex=True)
78-
)
79-
@pytest.mark.parametrize(
80-
"dtype_bins", get_all_dtypes(no_bool=True, no_complex=True)
81-
)
75+
@pytest.mark.parametrize("dtype_x", get_integer_float_dtypes())
76+
@pytest.mark.parametrize("dtype_bins", get_integer_float_dtypes())
8277
@pytest.mark.parametrize("right", [True, False])
8378
def test_digitize_diff_types(self, dtype_x, dtype_bins, right):
8479
x = numpy.array([1, 2, 3, 4, 5], dtype=dtype_x)
@@ -90,9 +85,7 @@ def test_digitize_diff_types(self, dtype_x, dtype_bins, right):
9085
expected = numpy.digitize(x, bins, right=right)
9186
assert_dtype_allclose(result, expected)
9287

93-
@pytest.mark.parametrize(
94-
"dtype", get_all_dtypes(no_bool=True, no_complex=True)
95-
)
88+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
9689
@pytest.mark.parametrize(
9790
"x, bins",
9891
[

dpnp/tests/test_linalg.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_all_dtypes,
2323
get_complex_dtypes,
2424
get_float_complex_dtypes,
25+
get_integer_float_dtypes,
2526
has_support_aspect64,
2627
is_cpu_device,
2728
is_cuda_device,
@@ -1409,9 +1410,7 @@ def test_einsum_tensor(self):
14091410
result = dpnp.einsum("ijij->", tensor_dp)
14101411
assert_dtype_allclose(result, expected)
14111412

1412-
@pytest.mark.parametrize(
1413-
"dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True)
1414-
)
1413+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
14151414
def test_different_paths(self, dtype):
14161415
# Simple test, designed to exercise most specialized code paths,
14171416
# note the +0.5 for floats. This makes sure we use a float value

dpnp/tests/test_logic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
get_all_dtypes,
99
get_float_complex_dtypes,
1010
get_float_dtypes,
11+
get_integer_float_dtypes,
1112
)
1213

1314

@@ -83,7 +84,7 @@ def check_raises(func_name, exception, *args, **kwargs):
8384
check_raises(func, TypeError, [0, 1, 2, 3])
8485

8586

86-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
87+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
8788
def test_allclose(dtype):
8889
a = numpy.random.rand(10)
8990
b = a + numpy.random.rand(10) * 1e-8
@@ -508,7 +509,7 @@ def test_infinity_sign_errors(func):
508509
getattr(dpnp, func)(x, out=out)
509510

510511

511-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
512+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
512513
@pytest.mark.parametrize(
513514
"rtol", [1e-05, dpnp.array(1e-05), dpnp.full(10, 1e-05)]
514515
)
@@ -549,7 +550,7 @@ def test_array_equiv(a, b):
549550
assert_equal(result, expected)
550551

551552

552-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
553+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
553554
def test_array_equiv_dtype(dtype):
554555
a = numpy.array([1, 2], dtype=dtype)
555556
b = numpy.array([1, 2], dtype=dtype)
@@ -575,7 +576,7 @@ def test_array_equiv_scalar(a):
575576
assert_equal(result, expected)
576577

577578

578-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
579+
@pytest.mark.parametrize("dtype", get_integer_float_dtypes())
579580
@pytest.mark.parametrize("equal_nan", [True, False])
580581
def test_array_equal_dtype(dtype, equal_nan):
581582
a = numpy.array([1, 2], dtype=dtype)

dpnp/tests/test_manipulation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_float_complex_dtypes,
2121
get_float_dtypes,
2222
get_integer_dtypes,
23+
get_integer_float_dtypes,
2324
has_support_aspect64,
2425
numpy_version,
2526
)
@@ -325,9 +326,7 @@ class TestCopyTo:
325326
]
326327
testdata += [
327328
([1, -1, 0], dtype)
328-
for dtype in get_all_dtypes(
329-
no_none=True, no_bool=True, no_complex=True, no_unsigned=True
330-
)
329+
for dtype in get_integer_float_dtypes(no_unsigned=True)
331330
]
332331
testdata += [([0.1, 0.0, -0.1], dtype) for dtype in get_float_dtypes()]
333332
testdata += [([1j, -1j, 1 - 2j], dtype) for dtype in get_complex_dtypes()]

0 commit comments

Comments
 (0)