Skip to content

Commit ab14950

Browse files
Used _searchsorted_left instead of tensor.searchsorted to avoid synchronization
Used asynchronous `_searchsorted_left` to insert operation into execution tree for unique_inverse and unique_all functions.
1 parent e5421fc commit ab14950

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

dpctl/tensor/_set_functions.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929
default_device_index_type,
3030
mask_positions,
3131
)
32-
from ._tensor_sorting_impl import _argsort_ascending, _sort_ascending
32+
from ._tensor_sorting_impl import (
33+
_argsort_ascending,
34+
_searchsorted_left,
35+
_sort_ascending,
36+
)
3337

3438
__all__ = [
3539
"unique_values",
@@ -365,7 +369,7 @@ def unique_inverse(x):
365369
unique_vals = dpt.empty(
366370
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
367371
)
368-
ht_ev, _ = _extract(
372+
ht_ev, uv_ev = _extract(
369373
src=s,
370374
cumsum=cumsum,
371375
axis_start=0,
@@ -403,8 +407,21 @@ def unique_inverse(x):
403407
depends=[set_ev, extr_ev],
404408
)
405409
host_tasks.append(ht_ev)
410+
411+
inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32
412+
inv = dpt.empty_like(x, dtype=inv_dt, order="C")
413+
ht_ev, _ = _searchsorted_left(
414+
hay=unique_vals,
415+
needles=x,
416+
positions=inv,
417+
sycl_queue=exec_q,
418+
depends=[
419+
uv_ev,
420+
],
421+
)
422+
host_tasks.append(ht_ev)
423+
406424
dpctl.SyclEvent.wait_for(host_tasks)
407-
inv = dpt.searchsorted(unique_vals, x)
408425
return UniqueInverseResult(unique_vals, inv)
409426

410427

@@ -532,7 +549,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
532549
unique_vals = dpt.empty(
533550
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
534551
)
535-
ht_ev, _ = _extract(
552+
ht_ev, uv_ev = _extract(
536553
src=s,
537554
cumsum=cumsum,
538555
axis_start=0,
@@ -562,15 +579,29 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
562579
)
563580
host_tasks.append(ht_ev)
564581
_counts = dpt.empty_like(cum_unique_counts[1:])
565-
ht_ev, _ = _subtract(
582+
ht_ev, sub_ev = _subtract(
566583
src1=cum_unique_counts[1:],
567584
src2=cum_unique_counts[:-1],
568585
dst=_counts,
569586
sycl_queue=exec_q,
570587
depends=[set_ev, extr_ev],
571588
)
572589
host_tasks.append(ht_ev)
573-
inv = dpt.searchsorted(unique_vals, x)
590+
591+
inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32
592+
inv = dpt.empty_like(x, dtype=inv_dt, order="C")
593+
ht_ev, _ = _searchsorted_left(
594+
hay=unique_vals,
595+
needles=x,
596+
positions=inv,
597+
sycl_queue=exec_q,
598+
depends=[
599+
uv_ev,
600+
],
601+
)
602+
host_tasks.append(ht_ev)
603+
604+
dpctl.SyclEvent.wait_for(host_tasks)
574605
return UniqueAllResult(
575606
unique_vals,
576607
sorting_ids[cum_unique_counts[:-1]],

0 commit comments

Comments
 (0)