Skip to content

Commit 2619205

Browse files
committed
address new comments
1 parent b913df8 commit 2619205

File tree

3 files changed

+13
-21
lines changed

3 files changed

+13
-21
lines changed

dpnp/tests/test_histogram.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ class TestDigitize:
4646
)
4747
def test_digitize(self, x, bins, dtype, right):
4848
x = get_abs_array(x, dtype)
49-
if numpy.issubdtype(dtype, numpy.unsignedinteger) and bins[0] == -4:
50-
# bins should be monotonically increasing, cannot use get_abs_array
51-
bins = numpy.array([0, 2, 4, 6, 8])
49+
if numpy.issubdtype(dtype, numpy.unsignedinteger):
50+
min_bin = bins.min()
51+
if min_bin < 0:
52+
# bins should be monotonically increasing, cannot use get_abs_array
53+
bins -= min_bin
5254
bins = bins.astype(dtype)
5355
x_dp = dpnp.array(x)
5456
bins_dp = dpnp.array(bins)

dpnp/tests/test_indexing.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,7 @@ def test_extract_diff_dtypes(self, a_dt, cond_dt):
137137

138138
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
139139
def test_extract_list_cond(self, a_dt):
140-
x = [-2, -1, 0, 1, 2, 3]
141-
if numpy.issubdtype(a_dt, numpy.unsignedinteger):
142-
x = numpy.abs(x)
143-
a = numpy.array(x, dtype=a_dt)
140+
a = get_abs_array([-2, -1, 0, 1, 2, 3], a_dt)
144141
cond = [1, -1, 2, 0, -2, 3]
145142
ia = dpnp.array(a)
146143

@@ -393,10 +390,7 @@ def test_0d(self, val):
393390

394391
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
395392
def test_1d(self, dtype):
396-
x = [1, 0, 2, -1, 0, 0, 8]
397-
if numpy.issubdtype(dtype, numpy.unsignedinteger):
398-
x = numpy.abs(x)
399-
a = numpy.array(x, dtype=dtype)
393+
a = get_abs_array([1, 0, 2, -1, 0, 0, 8], dtype)
400394
ia = dpnp.array(a)
401395

402396
np_res = numpy.nonzero(a)

dpnp/tests/test_mathematical.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
get_integer_dtypes,
3333
has_support_aspect16,
3434
has_support_aspect64,
35+
numpy_version,
3536
)
3637
from .test_umath import (
3738
_get_numpy_arrays_1in_1out,
@@ -1408,7 +1409,7 @@ class TestLdexp:
14081409
@pytest.mark.parametrize("exp_dt", get_integer_dtypes())
14091410
def test_basic(self, mant_dt, exp_dt):
14101411
if (
1411-
numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0"
1412+
numpy_version() < "2.0.0"
14121413
and exp_dt == numpy.int64
14131414
and numpy.dtype("l") != numpy.int64
14141415
):
@@ -1421,9 +1422,7 @@ def test_basic(self, mant_dt, exp_dt):
14211422
if dpnp.issubdtype(exp_dt, dpnp.uint64):
14221423
assert_raises(ValueError, dpnp.ldexp, imant, iexp)
14231424
assert_raises(TypeError, numpy.ldexp, mant, exp)
1424-
elif numpy.lib.NumpyVersion(
1425-
numpy.__version__
1426-
) < "2.0.0" and dpnp.issubdtype(exp_dt, dpnp.uint32):
1425+
elif numpy_version() < "2.0.0" and dpnp.issubdtype(exp_dt, dpnp.uint32):
14271426
# For this special case, NumPy < "2.0.0" raises an error on Windows
14281427
result = dpnp.ldexp(imant, iexp)
14291428
expected = numpy.ldexp(mant, exp.astype(numpy.int32))
@@ -2130,7 +2129,7 @@ def test_zeros(self, dt):
21302129

21312130
result = dpnp.spacing(ia)
21322131
expected = numpy.spacing(a)
2133-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
2132+
if numpy_version() < "2.0.0":
21342133
assert_equal(result, expected)
21352134
else:
21362135
# numpy.spacing(-0.0) == numpy.spacing(0.0), i.e. NumPy returns
@@ -2193,7 +2192,7 @@ def test_complex(self, xp):
21932192

21942193
class TestTrapezoid:
21952194
def get_numpy_func(self):
2196-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
2195+
if numpy_version() < "2.0.0":
21972196
# `trapz` is deprecated in NumPy 2.0
21982197
return numpy.trapz
21992198
return numpy.trapezoid
@@ -2753,10 +2752,7 @@ def test_out(self, func_params, dtype):
27532752
# NumPy < 2.0.0 while output has the dtype of input for NumPy >= 2.0.0
27542753
# (dpnp follows the latter behavior except for boolean dtype where it
27552754
# returns int8)
2756-
if (
2757-
numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0"
2758-
or dtype == numpy.bool
2759-
):
2755+
if numpy_version() < "2.0.0" or dtype == numpy.bool:
27602756
check_type = False
27612757
else:
27622758
check_type = True

0 commit comments

Comments
 (0)