Skip to content

Commit d20e08d

Browse files
Merge pull request #1584 from IntelPython/feature/searchsorted
Implement array-API's searchsorted
2 parents 6065822 + ab14950 commit d20e08d

File tree

14 files changed

+1436
-78
lines changed

14 files changed

+1436
-78
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ set(_reduction_sources
120120
set(_sorting_sources
121121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
122122
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
123+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
123124
)
124125
set(_tensor_impl_sources
125126
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_ctors.cpp
@@ -152,6 +153,7 @@ set(_tensor_reductions_impl_sources
152153
)
153154
set(_tensor_sorting_impl_sources
154155
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
156+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
155157
${_sorting_sources}
156158
)
157159
set(_linalg_sources

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@
181181
reduce_hypot,
182182
sum,
183183
)
184+
from ._searchsorted import searchsorted
184185
from ._set_functions import (
185186
unique_all,
186187
unique_counts,
@@ -365,4 +366,5 @@
365366
"matmul",
366367
"tensordot",
367368
"vecdot",
369+
"searchsorted",
368370
]

dpctl/tensor/_searchsorted.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from typing import Literal, Union
2+
3+
import dpctl
4+
import dpctl.utils as du
5+
6+
from ._copy_utils import _empty_like_orderK
7+
from ._ctors import empty
8+
from ._data_types import int32, int64
9+
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
10+
from ._tensor_impl import _take as ti_take
11+
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
12+
from ._type_utils import iinfo, isdtype, result_type
13+
from ._usmarray import usm_ndarray
14+
15+
16+
def searchsorted(
    x1: usm_ndarray,
    x2: usm_ndarray,
    /,
    *,
    side: Literal["left", "right"] = "left",
    sorter: Union[usm_ndarray, None] = None,
):
    """searchsorted(x1, x2, side='left', sorter=None)

    Finds the indices into `x1` such that, if the corresponding elements
    in `x2` were inserted before the indices, the order of `x1`, when sorted
    in ascending order, would be preserved.

    Args:
        x1 (usm_ndarray):
            input array. Must be a one-dimensional array. If `sorter` is
            `None`, must be sorted in ascending order; otherwise, `sorter` must
            be an array of indices that sort `x1` in ascending order.
        x2 (usm_ndarray):
            array containing search values.
        side (Literal["left", "right"]):
            argument controlling which index is returned if a value lands
            exactly on an edge. If `x2` is an array of rank `N` where
            `v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the
            return array `ret` contains the position `i` such that
            if `side="left"`, it is the first index such that
            `x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size`
            if `v > x1[-1]`;
            and if `side="right"`, it is the first position `i` such that
            `x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size`
            if `v >= x1[-1]`. Default: `"left"`.
        sorter (Optional[usm_ndarray]):
            array of indices that sort `x1` in ascending order. The array must
            have the same shape as `x1` and have an integral data type.
            Default: `None`.

    Returns:
        usm_ndarray:
            array of insertion positions, allocated like `x2`. Its data type
            is `int32` when both `x1.size` and `x2.size` fit in `int32`,
            and `int64` otherwise.

    Raises:
        TypeError:
            if `x1`, `x2`, or a non-`None` `sorter` is not a `usm_ndarray`.
        ValueError:
            if `side` is not `"left"` or `"right"`, if `x1` is not
            one-dimensional, or if `sorter` has a non-integral data type or
            a shape different from that of `x1`.
        dpctl.utils.ExecutionPlacementError:
            if a common execution queue can not be inferred from the inputs.
    """
    if not isinstance(x1, usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
    if not isinstance(x2, usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
    if sorter is not None and not isinstance(sorter, usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
        )

    if side not in ["left", "right"]:
        raise ValueError(
            "Unrecognized value of 'side' keyword argument. "
            "Expected either 'left' or 'right'"
        )

    # All participating arrays must agree on a single execution queue.
    queues = [x1.sycl_queue, x2.sycl_queue]
    if sorter is not None:
        queues.append(sorter.sycl_queue)
    q = du.get_execution_queue(queues)
    if q is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously "
            "inferred from input arguments."
        )

    if x1.ndim != 1:
        raise ValueError("First argument array must be one-dimensional")

    x1_dt = x1.dtype
    x2_dt = x2.dtype

    # Host-task events are collected and waited on before returning;
    # `ev` threads the dependency chain through the submitted kernels.
    host_evs = []
    ev = dpctl.SyclEvent()
    if sorter is not None:
        if not isdtype(sorter.dtype, "integral"):
            raise ValueError(
                f"Sorter array must have integral data type, got {sorter.dtype}"
            )
        if x1.shape != sorter.shape:
            raise ValueError(
                "Sorter array must be one-dimension with the same "
                "shape as the first argument array"
            )
        # Gather x1[sorter] into a temporary that replaces x1 below.
        res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
        ind = (sorter,)
        axis = 0
        wrap_out_of_bound_indices_mode = 0
        ht_ev, ev = ti_take(
            x1,
            ind,
            res,
            axis,
            wrap_out_of_bound_indices_mode,
            sycl_queue=q,
            depends=[
                ev,
            ],
        )
        x1 = res
        host_evs.append(ht_ev)

    if x1_dt != x2_dt:
        # Bring both operands to their common result type before searching.
        dt = result_type(x1, x2)
        if x1_dt != dt:
            x1_buf = _empty_like_orderK(x1, dt)
            ht_ev, ev = ti_copy(
                src=x1,
                dst=x1_buf,
                sycl_queue=q,
                depends=[
                    ev,
                ],
            )
            host_evs.append(ht_ev)
            x1 = x1_buf
        if x2_dt != dt:
            x2_buf = _empty_like_orderK(x2, dt)
            ht_ev, ev = ti_copy(
                src=x2,
                dst=x2_buf,
                sycl_queue=q,
                depends=[
                    ev,
                ],
            )
            host_evs.append(ht_ev)
            x2 = x2_buf

    dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])

    # The stored values are positions in `x1`, i.e. integers in the range
    # [0, x1.size], so the index type must be wide enough for x1.size.
    # int64 is also retained for very large `x2`, as was done previously.
    int32_max = iinfo(int32).max
    if x1.size > int32_max or x2.size > int32_max:
        dst_dt = int64
    else:
        dst_dt = int32

    dst = _empty_like_orderK(x2, dst_dt, usm_type=dst_usm_type)

    search_fn = _searchsorted_left if side == "left" else _searchsorted_right
    ht_ev, _ = search_fn(
        hay=x1,
        needles=x2,
        positions=dst,
        sycl_queue=q,
        depends=[
            ev,
        ],
    )

    host_evs.append(ht_ev)
    dpctl.SyclEvent.wait_for(host_evs)

    return dst

dpctl/tensor/_set_functions.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929
default_device_index_type,
3030
mask_positions,
3131
)
32-
from ._tensor_sorting_impl import _argsort_ascending, _sort_ascending
32+
from ._tensor_sorting_impl import (
33+
_argsort_ascending,
34+
_searchsorted_left,
35+
_sort_ascending,
36+
)
3337

3438
__all__ = [
3539
"unique_values",
@@ -365,7 +369,7 @@ def unique_inverse(x):
365369
unique_vals = dpt.empty(
366370
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
367371
)
368-
ht_ev, _ = _extract(
372+
ht_ev, uv_ev = _extract(
369373
src=s,
370374
cumsum=cumsum,
371375
axis_start=0,
@@ -403,23 +407,22 @@ def unique_inverse(x):
403407
depends=[set_ev, extr_ev],
404408
)
405409
host_tasks.append(ht_ev)
406-
# TODO: when searchsorted is available,
407-
# inv = searchsorted(unique_vals, fx)
408-
dpctl.SyclEvent.wait_for(host_tasks)
409-
counts = dpt.asnumpy(_counts).tolist()
410-
inv = dpt.empty_like(fx, dtype=ind_dt)
411-
pos = 0
412-
host_tasks = []
413-
for i in range(len(counts)):
414-
pos_next = pos + counts[i]
415-
_dst = inv[pos:pos_next]
416-
ht_ev, _ = _full_usm_ndarray(fill_value=i, dst=_dst, sycl_queue=exec_q)
417-
pos = pos_next
418-
host_tasks.append(ht_ev)
419-
dpctl.SyclEvent.wait_for(host_tasks)
420-
return UniqueInverseResult(
421-
unique_vals, dpt.reshape(inv[unsorting_ids], x.shape)
410+
411+
inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32
412+
inv = dpt.empty_like(x, dtype=inv_dt, order="C")
413+
ht_ev, _ = _searchsorted_left(
414+
hay=unique_vals,
415+
needles=x,
416+
positions=inv,
417+
sycl_queue=exec_q,
418+
depends=[
419+
uv_ev,
420+
],
422421
)
422+
host_tasks.append(ht_ev)
423+
424+
dpctl.SyclEvent.wait_for(host_tasks)
425+
return UniqueInverseResult(unique_vals, inv)
423426

424427

425428
def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
@@ -546,7 +549,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
546549
unique_vals = dpt.empty(
547550
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
548551
)
549-
ht_ev, _ = _extract(
552+
ht_ev, uv_ev = _extract(
550553
src=s,
551554
cumsum=cumsum,
552555
axis_start=0,
@@ -576,31 +579,32 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
576579
)
577580
host_tasks.append(ht_ev)
578581
_counts = dpt.empty_like(cum_unique_counts[1:])
579-
ht_ev, _ = _subtract(
582+
ht_ev, sub_ev = _subtract(
580583
src1=cum_unique_counts[1:],
581584
src2=cum_unique_counts[:-1],
582585
dst=_counts,
583586
sycl_queue=exec_q,
584587
depends=[set_ev, extr_ev],
585588
)
586589
host_tasks.append(ht_ev)
587-
# TODO: when searchsorted is available,
588-
# inv = searchsorted(unique_vals, fx)
589-
dpctl.SyclEvent.wait_for(host_tasks)
590-
counts = dpt.asnumpy(_counts).tolist()
591-
inv = dpt.empty_like(fx, dtype=ind_dt)
592-
pos = 0
593-
host_tasks = []
594-
for i in range(len(counts)):
595-
pos_next = pos + counts[i]
596-
_dst = inv[pos:pos_next]
597-
ht_ev, _ = _full_usm_ndarray(fill_value=i, dst=_dst, sycl_queue=exec_q)
598-
pos = pos_next
599-
host_tasks.append(ht_ev)
590+
591+
inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32
592+
inv = dpt.empty_like(x, dtype=inv_dt, order="C")
593+
ht_ev, _ = _searchsorted_left(
594+
hay=unique_vals,
595+
needles=x,
596+
positions=inv,
597+
sycl_queue=exec_q,
598+
depends=[
599+
uv_ev,
600+
],
601+
)
602+
host_tasks.append(ht_ev)
603+
600604
dpctl.SyclEvent.wait_for(host_tasks)
601605
return UniqueAllResult(
602606
unique_vals,
603607
sorting_ids[cum_unique_counts[:-1]],
604-
dpt.reshape(inv[unsorting_ids], x.shape),
608+
inv,
605609
_counts,
606610
)

0 commit comments

Comments
 (0)