Skip to content

Commit b0cc135

Browse files
ndgrigorianoleksandr-pavlyk
authored andcommitted
Adds tests for unique functions on strided inputs
1 parent 81610d6 commit b0cc135

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

dpctl/tests/test_usm_ndarray_unique.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ def test_unique_values(dtype):
5353
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
5454

5555

56+
def test_unique_values_strided():
57+
get_queue_or_skip()
58+
59+
n, m = 1000, 20
60+
inp = dpt.ones((n, m), dtype="i4", order="F")
61+
inp[:, ::2] = 0
62+
63+
uv = dpt.unique_values(inp)
64+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
65+
66+
5667
@pytest.mark.parametrize(
5768
"dtype",
5869
[
@@ -86,6 +97,18 @@ def test_unique_counts(dtype):
8697
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))
8798

8899

100+
def test_unique_counts_strided():
101+
get_queue_or_skip()
102+
103+
n, m = 1000, 20
104+
inp = dpt.ones((n, m), dtype="i4", order="F")
105+
inp[:, ::2] = 0
106+
107+
uv, uv_counts = dpt.unique_counts(inp)
108+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
109+
assert dpt.all(uv_counts == dpt.full(2, n / 2 * m, dtype=uv_counts.dtype))
110+
111+
89112
@pytest.mark.parametrize(
90113
"dtype",
91114
[
@@ -120,6 +143,19 @@ def test_unique_inverse(dtype):
120143
assert inp.shape == inv.shape
121144

122145

146+
def test_unique_inverse_strided():
147+
get_queue_or_skip()
148+
149+
n, m = 1000, 20
150+
inp = dpt.ones((n, m), dtype="i4", order="F")
151+
inp[:, ::2] = 0
152+
153+
uv, inv = dpt.unique_inverse(inp)
154+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
155+
assert dpt.all(inp == uv[inv])
156+
assert inp.shape == inv.shape
157+
158+
123159
@pytest.mark.parametrize(
124160
"dtype",
125161
[
@@ -156,6 +192,21 @@ def test_unique_all(dtype):
156192
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))
157193

158194

195+
def test_unique_all_strided():
196+
get_queue_or_skip()
197+
198+
n, m = 1000, 20
199+
inp = dpt.ones((n, m), dtype="i4", order="F")
200+
inp[:, ::2] = 0
201+
202+
uv, ind, inv, uv_counts = dpt.unique_all(inp)
203+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
204+
assert dpt.all(uv == dpt.reshape(inp, -1)[ind])
205+
assert dpt.all(inp == uv[inv])
206+
assert inp.shape == inv.shape
207+
assert dpt.all(uv_counts == dpt.full(2, n / 2 * m, dtype=uv_counts.dtype))
208+
209+
159210
def test_set_functions_empty_input():
160211
get_queue_or_skip()
161212
x = dpt.ones((10, 0, 1), dtype="i4")

0 commit comments

Comments
 (0)