Skip to content

Commit a50c806

Browse files
committed
Update bincount tests
1 parent 0579afa commit a50c806

File tree

1 file changed

+50
-44
lines changed

1 file changed

+50
-44
lines changed

dpnp/tests/test_histogram.py

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
from .helper import (
1414
assert_dtype_allclose,
15+
generate_random_numpy_array,
1516
get_abs_array,
1617
get_all_dtypes,
18+
get_complex_dtypes,
19+
get_float_complex_dtypes,
1720
get_float_dtypes,
1821
get_integer_dtypes,
1922
has_support_aspect64,
@@ -532,34 +535,29 @@ def test_range(self, range, dtype):
532535

533536

534537
class TestBincount:
535-
@pytest.mark.parametrize("dtype", get_integer_dtypes())
536-
def test_rand_data(self, dtype):
537-
n = 100
538-
upper_bound = 10 if dtype != dpnp.bool_ else 1
539-
v = numpy.random.randint(0, upper_bound, size=n, dtype=dtype)
538+
@pytest.mark.parametrize("dt", get_integer_dtypes() + [numpy.bool_])
539+
def test_rand_data(self, dt):
540+
v = generate_random_numpy_array(100, dtype=dt, low=0)
540541
iv = dpnp.array(v)
541542

542-
if numpy.issubdtype(dtype, numpy.uint64):
543-
# discussed in numpy issue 17760
544-
assert_raises(TypeError, numpy.bincount, v)
545-
assert_raises(ValueError, dpnp.bincount, iv)
546-
else:
547-
expected_hist = numpy.bincount(v)
548-
result_hist = dpnp.bincount(iv)
549-
assert_array_equal(result_hist, expected_hist)
543+
if numpy.issubdtype(dt, numpy.uint64) and numpy_version() < "2.2.4":
544+
v = v.astype(numpy.int64)
550545

551-
@pytest.mark.parametrize("dtype", get_integer_dtypes())
552-
def test_arange_data(self, dtype):
553-
v = numpy.arange(100).astype(dtype)
546+
expected_hist = numpy.bincount(v)
547+
result_hist = dpnp.bincount(iv)
548+
assert_array_equal(result_hist, expected_hist)
549+
550+
@pytest.mark.parametrize("dt", get_integer_dtypes())
551+
def test_arange_data(self, dt):
552+
v = numpy.arange(100, dtype=dt)
554553
iv = dpnp.array(v)
555554

556-
if numpy.issubdtype(dtype, numpy.uint64):
557-
assert_raises(TypeError, numpy.bincount, v)
558-
assert_raises(ValueError, dpnp.bincount, iv)
559-
else:
560-
expected_hist = numpy.bincount(v)
561-
result_hist = dpnp.bincount(iv)
562-
assert_array_equal(result_hist, expected_hist)
555+
if numpy.issubdtype(dt, numpy.uint64) and numpy_version() < "2.2.4":
556+
v = v.astype(numpy.int64)
557+
558+
expected_hist = numpy.bincount(v)
559+
result_hist = dpnp.bincount(iv)
560+
assert_array_equal(result_hist, expected_hist)
563561

564562
@pytest.mark.parametrize("xp", [numpy, dpnp])
565563
def test_negative_values(self, xp):
@@ -581,11 +579,17 @@ def test_weights_another_sycl_queue(self):
581579
dpnp.bincount(v, weights=w)
582580

583581
@pytest.mark.parametrize("xp", [numpy, dpnp])
584-
def test_weights_unsupported_dtype(self, xp):
585-
v = dpnp.arange(5)
586-
w = dpnp.arange(5, dtype=dpnp.complex64)
587-
with assert_raises(ValueError):
588-
dpnp.bincount(v, weights=w)
582+
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
583+
def test_data_unsupported_dtype(self, xp, dt):
584+
v = xp.arange(5, dtype=dt)
585+
assert_raises(TypeError, xp.bincount, v)
586+
587+
@pytest.mark.parametrize("xp", [numpy, dpnp])
588+
@pytest.mark.parametrize("dt", get_complex_dtypes())
589+
def test_weights_unsupported_dtype(self, xp, dt):
590+
v = xp.arange(5)
591+
w = xp.arange(5, dtype=dt)
592+
assert_raises((TypeError, ValueError), xp.bincount, v, weights=w)
589593

590594
@pytest.mark.parametrize(
591595
"bins_count",
@@ -606,11 +610,11 @@ def test_different_bins_amount(self, bins_count):
606610
)
607611
@pytest.mark.parametrize("minlength", [0, 1, 3, 5])
608612
def test_minlength(self, array, minlength):
609-
np_a = numpy.array(array)
610-
dpnp_a = dpnp.array(array)
613+
a = numpy.array(array)
614+
ia = dpnp.array(a)
611615

612-
expected = numpy.bincount(np_a, minlength=minlength)
613-
result = dpnp.bincount(dpnp_a, minlength=minlength)
616+
expected = numpy.bincount(a, minlength=minlength)
617+
result = dpnp.bincount(ia, minlength=minlength)
614618
assert_allclose(result, expected)
615619

616620
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@@ -639,21 +643,23 @@ def test_minlength_none(self, xp):
639643
)
640644

641645
@pytest.mark.parametrize(
642-
"array", [[1, 2, 2, 1, 2, 4]], ids=["[1, 2, 2, 1, 2, 4]"]
646+
"weights",
647+
[None, [0.3, 0.5, 0, 0.7, 1.0, -0.6], [2, 2, 2, 2, 2, 2]],
648+
ids=["None", "float_data", "int_data"],
643649
)
644650
@pytest.mark.parametrize(
645-
"weights",
646-
[None, [0.3, 0.5, 0.2, 0.7, 1.0, -0.6], [2, 2, 2, 2, 2, 2]],
647-
ids=["None", "[0.3, 0.5, 0.2, 0.7, 1., -0.6]", "[2, 2, 2, 2, 2, 2]"],
651+
"dt", get_all_dtypes(no_none=True, no_complex=True)
648652
)
649-
def test_weights(self, array, weights):
650-
np_a = numpy.array(array)
651-
np_weights = numpy.array(weights) if weights is not None else weights
652-
dpnp_a = dpnp.array(array)
653-
dpnp_weights = dpnp.array(weights) if weights is not None else weights
654-
655-
expected = numpy.bincount(np_a, weights=np_weights)
656-
result = dpnp.bincount(dpnp_a, weights=dpnp_weights)
653+
def test_weights(self, weights, dt):
654+
a = numpy.array([1, 2, 2, 1, 2, 4])
655+
ia = dpnp.array(a)
656+
w = iw = None
657+
if weights is not None:
658+
w = numpy.array(weights, dtype=dt)
659+
iw = dpnp.array(w)
660+
661+
expected = numpy.bincount(a, weights=w)
662+
result = dpnp.bincount(ia, weights=iw)
657663
assert_allclose(result, expected)
658664

659665
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)