Skip to content

Commit 4ecf30a

Browse files
committed
Fixes to unique functions, adds tests
`unique_all` would not return a `UniqueAllResult` tuple when exiting early for an array of all-unique elements. Fixed a problem throughout the `unique` functions where the cumulative sum was allocated on a separate (default) queue, causing inputs with a non-default queue to fail. The `unique` functions now behave correctly for 0-size array inputs.
1 parent 06a7970 commit 4ecf30a

File tree

2 files changed

+112
-14
lines changed

2 files changed

+112
-14
lines changed

dpctl/tensor/_set_functions_async.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
8080
fx = x
8181
else:
8282
fx = dpt.reshape(x, (x.size,), order="C")
83+
if fx.size == 0:
84+
return fx
8385
s = dpt.empty_like(fx, order="C")
8486
host_tasks = []
8587
if fx.flags.c_contiguous:
@@ -114,7 +116,7 @@ def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
114116
fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
115117
)
116118
host_tasks.append(ht_ev)
117-
cumsum = dpt.empty(s.shape, dtype=dpt.int64)
119+
cumsum = dpt.empty(s.shape, dtype=dpt.int64, sycl_queue=exec_q)
118120
# synchronizing call
119121
n_uniques = mask_positions(
120122
unique_mask, cumsum, sycl_queue=exec_q, depends=[one_ev, uneq_ev]
@@ -163,10 +165,14 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
163165
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
164166
array_api_dev = x.device
165167
exec_q = array_api_dev.sycl_queue
168+
x_usm_type = x.usm_type
166169
if x.ndim == 1:
167170
fx = x
168171
else:
169172
fx = dpt.reshape(x, (x.size,), order="C")
173+
ind_dt = default_device_index_type(exec_q)
174+
if fx.size == 0:
175+
return UniqueCountsResult(fx, dpt.empty_like(fx, dtype=ind_dt))
170176
s = dpt.empty_like(fx, order="C")
171177
host_tasks = []
172178
if fx.flags.c_contiguous:
@@ -201,17 +207,21 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
201207
fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
202208
)
203209
host_tasks.append(ht_ev)
204-
ind_dt = default_device_index_type(exec_q)
205-
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64)
210+
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
206211
# synchronizing call
207212
n_uniques = mask_positions(
208213
unique_mask, cumsum, sycl_queue=exec_q, depends=[one_ev, uneq_ev]
209214
)
210215
if n_uniques == fx.size:
211216
dpctl.SyclEvent.wait_for(host_tasks)
212-
return UniqueCountsResult(s, dpt.ones(n_uniques, dtype=ind_dt))
217+
return UniqueCountsResult(
218+
s,
219+
dpt.ones(
220+
n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
221+
),
222+
)
213223
unique_vals = dpt.empty(
214-
n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
224+
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
215225
)
216226
# populate unique values
217227
ht_ev, _ = _extract(
@@ -224,7 +234,7 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
224234
)
225235
host_tasks.append(ht_ev)
226236
unique_counts = dpt.empty(
227-
n_uniques + 1, dtype=ind_dt, usm_type=x.usm_type, sycl_queue=exec_q
237+
n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
228238
)
229239
idx = dpt.empty(x.size, dtype=ind_dt, sycl_queue=exec_q)
230240
ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q)
@@ -281,13 +291,16 @@ def unique_inverse(x):
281291
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
282292
array_api_dev = x.device
283293
exec_q = array_api_dev.sycl_queue
294+
x_usm_type = x.usm_type
284295
if x.ndim == 1:
285296
fx = x
286297
else:
287298
fx = dpt.reshape(x, (x.size,), order="C")
288299
ind_dt = default_device_index_type(exec_q)
289300
sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
290301
unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
302+
if fx.size == 0:
303+
return UniqueInverseResult(fx, unsorting_ids)
291304
host_tasks = []
292305
if fx.flags.c_contiguous:
293306
ht_ev, sort_ev = _argsort_ascending(
@@ -341,7 +354,7 @@ def unique_inverse(x):
341354
fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
342355
)
343356
host_tasks.append(ht_ev)
344-
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64)
357+
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
345358
# synchronizing call
346359
n_uniques = mask_positions(
347360
unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev]
@@ -350,7 +363,7 @@ def unique_inverse(x):
350363
dpctl.SyclEvent.wait_for(host_tasks)
351364
return UniqueInverseResult(s, unsorting_ids)
352365
unique_vals = dpt.empty(
353-
n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
366+
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
354367
)
355368
ht_ev, _ = _extract(
356369
src=s,
@@ -362,7 +375,7 @@ def unique_inverse(x):
362375
)
363376
host_tasks.append(ht_ev)
364377
cum_unique_counts = dpt.empty(
365-
n_uniques + 1, dtype=ind_dt, usm_type=x.usm_type, sycl_queue=exec_q
378+
n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
366379
)
367380
idx = dpt.empty(x.size, dtype=ind_dt, sycl_queue=exec_q)
368381
ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q)
@@ -442,13 +455,20 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
442455
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
443456
array_api_dev = x.device
444457
exec_q = array_api_dev.sycl_queue
458+
x_usm_type = x.usm_type
445459
if x.ndim == 1:
446460
fx = x
447461
else:
448-
fx = dpt.reshape(x, (x.size,), order="C", copy=False)
462+
fx = dpt.reshape(x, (x.size,), order="C")
449463
ind_dt = default_device_index_type(exec_q)
450464
sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
451465
unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
466+
if fx.size == 0:
467+
# original array contains no data
468+
# so it can be safely returned as values
469+
return UniqueAllResult(
470+
fx, sorting_ids, unsorting_ids, dpt.empty_like(fx, dtype=ind_dt)
471+
)
452472
host_tasks = []
453473
if fx.flags.c_contiguous:
454474
ht_ev, sort_ev = _argsort_ascending(
@@ -502,16 +522,24 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
502522
fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
503523
)
504524
host_tasks.append(ht_ev)
505-
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64)
525+
cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
506526
# synchronizing call
507527
n_uniques = mask_positions(
508528
unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev]
509529
)
510530
if n_uniques == fx.size:
511531
dpctl.SyclEvent.wait_for(host_tasks)
512-
return UniqueInverseResult(s, unsorting_ids)
532+
_counts = dpt.ones(
533+
n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
534+
)
535+
return UniqueAllResult(
536+
s,
537+
sorting_ids,
538+
unsorting_ids,
539+
_counts,
540+
)
513541
unique_vals = dpt.empty(
514-
n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
542+
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
515543
)
516544
ht_ev, _ = _extract(
517545
src=s,
@@ -523,7 +551,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
523551
)
524552
host_tasks.append(ht_ev)
525553
cum_unique_counts = dpt.empty(
526-
n_uniques + 1, dtype=ind_dt, usm_type=x.usm_type, sycl_queue=exec_q
554+
n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
527555
)
528556
idx = dpt.empty(x.size, dtype=ind_dt, sycl_queue=exec_q)
529557
ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q)

dpctl/tests/test_usm_ndarray_unique.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818

19+
import dpctl
1920
import dpctl.tensor as dpt
2021
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2122

@@ -151,3 +152,72 @@ def test_unique_all(dtype):
151152
assert dpt.all(uv == inp[ind])
152153
assert dpt.all(inp == uv[inv])
153154
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))
155+
156+
157+
def test_set_functions_empty_input():
158+
get_queue_or_skip()
159+
x = dpt.ones((10, 0, 1), dtype="i4")
160+
161+
res = dpt.unique_values(x)
162+
assert isinstance(res, dpctl.tensor.usm_ndarray)
163+
assert res.size == 0
164+
assert res.dtype == x.dtype
165+
166+
res = dpt.unique_inverse(x)
167+
assert type(res).__name__ == "UniqueInverseResult"
168+
uv, inv = res
169+
assert isinstance(uv, dpctl.tensor.usm_ndarray)
170+
assert uv.size == 0
171+
assert isinstance(inv, dpctl.tensor.usm_ndarray)
172+
assert inv.size == 0
173+
174+
res = dpt.unique_counts(x)
175+
assert type(res).__name__ == "UniqueCountsResult"
176+
uv, uv_counts = res
177+
assert isinstance(uv, dpctl.tensor.usm_ndarray)
178+
assert uv.size == 0
179+
assert isinstance(uv_counts, dpctl.tensor.usm_ndarray)
180+
assert uv_counts.size == 0
181+
182+
res = dpt.unique_all(x)
183+
assert type(res).__name__ == "UniqueAllResult"
184+
uv, ind, inv, uv_counts = res
185+
assert isinstance(uv, dpctl.tensor.usm_ndarray)
186+
assert uv.size == 0
187+
assert isinstance(ind, dpctl.tensor.usm_ndarray)
188+
assert ind.size == 0
189+
assert isinstance(inv, dpctl.tensor.usm_ndarray)
190+
assert inv.size == 0
191+
assert isinstance(uv_counts, dpctl.tensor.usm_ndarray)
192+
assert uv_counts.size == 0
193+
194+
195+
def test_set_function_outputs():
196+
get_queue_or_skip()
197+
# check standard and early exit paths
198+
x1 = dpt.arange(10, dtype="i4")
199+
x2 = dpt.ones((10, 10), dtype="i4")
200+
201+
assert isinstance(dpt.unique_values(x1), dpctl.tensor.usm_ndarray)
202+
assert isinstance(dpt.unique_values(x2), dpctl.tensor.usm_ndarray)
203+
204+
assert type(dpt.unique_inverse(x1)).__name__ == "UniqueInverseResult"
205+
assert type(dpt.unique_inverse(x2)).__name__ == "UniqueInverseResult"
206+
207+
assert type(dpt.unique_counts(x1)).__name__ == "UniqueCountsResult"
208+
assert type(dpt.unique_counts(x2)).__name__ == "UniqueCountsResult"
209+
210+
assert type(dpt.unique_all(x1)).__name__ == "UniqueAllResult"
211+
assert type(dpt.unique_all(x2)).__name__ == "UniqueAllResult"
212+
213+
214+
def test_set_functions_compute_follows_data():
215+
# tests that all intermediate calls and allocations
216+
# are compatible with an input with an arbitrary queue
217+
q = dpctl.SyclQueue()
218+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
219+
220+
assert isinstance(dpt.unique_values(x), dpctl.tensor.usm_ndarray)
221+
assert dpt.unique_counts(x)
222+
assert dpt.unique_inverse(x)
223+
assert dpt.unique_all(x)

0 commit comments

Comments
 (0)