Skip to content

Commit 94a2ebf

Browse files
committed
Tests for strided 1D data input for unique functions
1 parent 938566c commit 94a2ebf

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

dpctl/tests/test_usm_ndarray_unique.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def test_unique_values_strided():
6363
uv = dpt.unique_values(inp)
6464
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
6565

66+
inp = dpt.reshape(inp, -1)
67+
inp = dpt.flip(dpt.reshape(inp, -1))
68+
69+
uv = dpt.unique_values(inp)
70+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
71+
6672

6773
@pytest.mark.parametrize(
6874
"dtype",
@@ -108,6 +114,12 @@ def test_unique_counts_strided():
108114
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
109115
assert dpt.all(uv_counts == dpt.full(2, n / 2 * m, dtype=uv_counts.dtype))
110116

117+
inp = dpt.flip(dpt.reshape(inp, -1))
118+
119+
uv, uv_counts = dpt.unique_counts(inp)
120+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
121+
assert dpt.all(uv_counts == dpt.full(2, n / 2 * m, dtype=uv_counts.dtype))
122+
111123

112124
@pytest.mark.parametrize(
113125
"dtype",
@@ -155,6 +167,13 @@ def test_unique_inverse_strided():
155167
assert dpt.all(inp == uv[inv])
156168
assert inp.shape == inv.shape
157169

170+
inp = dpt.flip(dpt.reshape(inp, -1))
171+
172+
uv, inv = dpt.unique_inverse(inp)
173+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
174+
assert dpt.all(inp == uv[inv])
175+
assert inp.shape == inv.shape
176+
158177

159178
@pytest.mark.parametrize(
160179
"dtype",
@@ -206,6 +225,15 @@ def test_unique_all_strided():
206225
assert inp.shape == inv.shape
207226
assert dpt.all(uv_counts == dpt.full(2, n / 2 * m, dtype=uv_counts.dtype))
208227

228+
inp = dpt.flip(dpt.reshape(inp, -1))
229+
230+
uv, ind, inv, uv_counts = dpt.unique_all(inp)
231+
assert dpt.all(uv == dpt.arange(2, dtype="i4"))
232+
assert dpt.all(uv == inp[ind])
233+
assert dpt.all(inp == uv[inv])
234+
assert inp.shape == inv.shape
235+
assert dpt.all(uv_counts == dpt.full(2, n / 2 * m, dtype=uv_counts.dtype))
236+
209237

210238
def test_set_functions_empty_input():
211239
get_queue_or_skip()

0 commit comments

Comments
 (0)