Skip to content

Commit e5421fc

Browse files
Avoid unnecessary synchronizations when applying sorter or promoting types
1 parent 95970e6 commit e5421fc

File tree

1 file changed

+72
-12
lines changed

1 file changed

+72
-12
lines changed

dpctl/tensor/_searchsorted.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from typing import Literal, Union
22

3+
import dpctl
34
import dpctl.utils as du
45

5-
from ._copy_utils import astype
6+
from ._copy_utils import _empty_like_orderK
67
from ._ctors import empty
78
from ._data_types import int32, int64
9+
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
10+
from ._tensor_impl import _take as ti_take
811
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
912
from ._type_utils import iinfo, isdtype, result_type
1013
from ._usmarray import usm_ndarray
@@ -74,6 +77,14 @@ def searchsorted(
7477
"inferred from input arguments."
7578
)
7679

80+
if x1.ndim != 1:
81+
raise ValueError("First argument array must be one-dimensional")
82+
83+
x1_dt = x1.dtype
84+
x2_dt = x2.dtype
85+
86+
host_evs = []
87+
ev = dpctl.SyclEvent()
7788
if sorter is not None:
7889
if not isdtype(sorter.dtype, "integral"):
7990
raise ValueError(
@@ -84,29 +95,78 @@ def searchsorted(
8495
"Sorter array must be one-dimension with the same "
8596
"shape as the first argument array"
8697
)
87-
x1 = x1[sorter]
88-
89-
if x1.ndim != 1:
90-
raise ValueError("First argument array must be one-dimensional")
98+
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
99+
ind = (sorter,)
100+
axis = 0
101+
wrap_out_of_bound_indices_mode = 0
102+
ht_ev, ev = ti_take(
103+
x1,
104+
ind,
105+
res,
106+
axis,
107+
wrap_out_of_bound_indices_mode,
108+
sycl_queue=q,
109+
depends=[
110+
ev,
111+
],
112+
)
113+
x1 = res
114+
host_evs.append(ht_ev)
91115

92-
if x1.dtype != x2.dtype:
116+
if x1_dt != x2_dt:
93117
dt = result_type(x1, x2)
94-
x1 = astype(x1, dt, copy=None)
95-
x2 = astype(x2, dt, copy=None)
118+
if x1_dt != dt:
119+
x1_buf = _empty_like_orderK(x1, dt)
120+
ht_ev, ev = ti_copy(
121+
src=x1,
122+
dst=x1_buf,
123+
sycl_queue=q,
124+
depends=[
125+
ev,
126+
],
127+
)
128+
host_evs.append(ht_ev)
129+
x1 = x1_buf
130+
if x2_dt != dt:
131+
x2_buf = _empty_like_orderK(x2, dt)
132+
ht_ev, ev = ti_copy(
133+
src=x2,
134+
dst=x2_buf,
135+
sycl_queue=q,
136+
depends=[
137+
ev,
138+
],
139+
)
140+
host_evs.append(ht_ev)
141+
x2 = x2_buf
96142

97143
dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
98144
dst_dt = int32 if x2.size <= iinfo(int32).max else int64
99145

100-
dst = empty(x2.shape, dtype=dst_dt, usm_type=dst_usm_type, sycl_queue=q)
146+
dst = _empty_like_orderK(x2, dst_dt, usm_type=dst_usm_type)
101147

102148
if side == "left":
103149
ht_ev, _ = _searchsorted_left(
104-
hay=x1, needles=x2, positions=dst, sycl_queue=q
150+
hay=x1,
151+
needles=x2,
152+
positions=dst,
153+
sycl_queue=q,
154+
depends=[
155+
ev,
156+
],
105157
)
106158
else:
107159
ht_ev, _ = _searchsorted_right(
108-
hay=x1, needles=x2, positions=dst, sycl_queue=q
160+
hay=x1,
161+
needles=x2,
162+
positions=dst,
163+
sycl_queue=q,
164+
depends=[
165+
ev,
166+
],
109167
)
110-
ht_ev.wait()
168+
169+
host_evs.append(ht_ev)
170+
dpctl.SyclEvent.wait_for(host_evs)
111171

112172
return dst

0 commit comments

Comments
 (0)