Skip to content

Commit ec299bc

Browse files
committed
Add USM type and SYCL queue tests for compress
1 parent 81d8181 commit ec299bc

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2855,3 +2855,24 @@ def test_ix(device_0, device_1):
28552855
ixgrid = dpnp.ix_(x0, x1)
28562856
assert_sycl_queue_equal(ixgrid[0].sycl_queue, x0.sycl_queue)
28572857
assert_sycl_queue_equal(ixgrid[1].sycl_queue, x1.sycl_queue)
2858+
2859+
2860+
@pytest.mark.parametrize(
2861+
"device",
2862+
valid_devices,
2863+
ids=[device.filter_string for device in valid_devices],
2864+
)
2865+
def test_compress(device):
2866+
a_np = numpy.arange(5)
2867+
a = dpnp.array(a_np, device=device)
2868+
2869+
cond_np = numpy.array([0, 1, 0])
2870+
cond = dpnp.array(cond_np, device=device)
2871+
2872+
expected = numpy.compress(cond_np, a_np, axis=None)
2873+
result = dpnp.compress(cond, a, axis=None)
2874+
assert_allclose(expected, result)
2875+
2876+
expected_queue = a.sycl_queue
2877+
result_queue = result.sycl_queue
2878+
assert_sycl_queue_equal(result_queue, expected_queue)

dpnp/tests/test_usm_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,3 +1746,17 @@ def test_ix(usm_type_0, usm_type_1):
17461746
ixgrid = dp.ix_(x0, x1)
17471747
assert ixgrid[0].usm_type == x0.usm_type
17481748
assert ixgrid[1].usm_type == x1.usm_type
1749+
1750+
1751+
@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
1752+
@pytest.mark.parametrize(
1753+
"usm_type_cond", list_of_usm_types, ids=list_of_usm_types
1754+
)
1755+
def test_compress(usm_type_a, usm_type_cond):
1756+
a = dp.arange(5, usm_type=usm_type_a)
1757+
cond = dp.array([False, True, True], usm_type=usm_type_cond)
1758+
z = dp.compress(cond, a, axis=None)
1759+
1760+
assert a.usm_type == usm_type_a
1761+
assert cond.usm_type == usm_type_cond
1762+
assert z.usm_type == du.get_coerced_usm_type([usm_type_a, usm_type_cond])

0 commit comments

Comments
 (0)