1
1
from typing import Literal , Union
2
2
3
+ import dpctl
3
4
import dpctl .utils as du
4
5
5
- from ._copy_utils import astype
6
+ from ._copy_utils import _empty_like_orderK
6
7
from ._ctors import empty
7
8
from ._data_types import int32 , int64
9
+ from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
10
+ from ._tensor_impl import _take as ti_take
8
11
from ._tensor_sorting_impl import _searchsorted_left , _searchsorted_right
9
12
from ._type_utils import iinfo , isdtype , result_type
10
13
from ._usmarray import usm_ndarray
@@ -74,6 +77,14 @@ def searchsorted(
74
77
"inferred from input arguments."
75
78
)
76
79
80
+ if x1 .ndim != 1 :
81
+ raise ValueError ("First argument array must be one-dimensional" )
82
+
83
+ x1_dt = x1 .dtype
84
+ x2_dt = x2 .dtype
85
+
86
+ host_evs = []
87
+ ev = dpctl .SyclEvent ()
77
88
if sorter is not None :
78
89
if not isdtype (sorter .dtype , "integral" ):
79
90
raise ValueError (
@@ -84,29 +95,78 @@ def searchsorted(
84
95
"Sorter array must be one-dimension with the same "
85
96
"shape as the first argument array"
86
97
)
87
- x1 = x1 [sorter ]
88
-
89
- if x1 .ndim != 1 :
90
- raise ValueError ("First argument array must be one-dimensional" )
98
+ res = empty (x1 .shape , dtype = x1_dt , usm_type = x1 .usm_type , sycl_queue = q )
99
+ ind = (sorter ,)
100
+ axis = 0
101
+ wrap_out_of_bound_indices_mode = 0
102
+ ht_ev , ev = ti_take (
103
+ x1 ,
104
+ ind ,
105
+ res ,
106
+ axis ,
107
+ wrap_out_of_bound_indices_mode ,
108
+ sycl_queue = q ,
109
+ depends = [
110
+ ev ,
111
+ ],
112
+ )
113
+ x1 = res
114
+ host_evs .append (ht_ev )
91
115
92
- if x1 . dtype != x2 . dtype :
116
+ if x1_dt != x2_dt :
93
117
dt = result_type (x1 , x2 )
94
- x1 = astype (x1 , dt , copy = None )
95
- x2 = astype (x2 , dt , copy = None )
118
+ if x1_dt != dt :
119
+ x1_buf = _empty_like_orderK (x1 , dt )
120
+ ht_ev , ev = ti_copy (
121
+ src = x1 ,
122
+ dst = x1_buf ,
123
+ sycl_queue = q ,
124
+ depends = [
125
+ ev ,
126
+ ],
127
+ )
128
+ host_evs .append (ht_ev )
129
+ x1 = x1_buf
130
+ if x2_dt != dt :
131
+ x2_buf = _empty_like_orderK (x2 , dt )
132
+ ht_ev , ev = ti_copy (
133
+ src = x2 ,
134
+ dst = x2_buf ,
135
+ sycl_queue = q ,
136
+ depends = [
137
+ ev ,
138
+ ],
139
+ )
140
+ host_evs .append (ht_ev )
141
+ x2 = x2_buf
96
142
97
143
dst_usm_type = du .get_coerced_usm_type ([x1 .usm_type , x2 .usm_type ])
98
144
dst_dt = int32 if x2 .size <= iinfo (int32 ).max else int64
99
145
100
- dst = empty (x2 . shape , dtype = dst_dt , usm_type = dst_usm_type , sycl_queue = q )
146
+ dst = _empty_like_orderK (x2 , dst_dt , usm_type = dst_usm_type )
101
147
102
148
if side == "left" :
103
149
ht_ev , _ = _searchsorted_left (
104
- hay = x1 , needles = x2 , positions = dst , sycl_queue = q
150
+ hay = x1 ,
151
+ needles = x2 ,
152
+ positions = dst ,
153
+ sycl_queue = q ,
154
+ depends = [
155
+ ev ,
156
+ ],
105
157
)
106
158
else :
107
159
ht_ev , _ = _searchsorted_right (
108
- hay = x1 , needles = x2 , positions = dst , sycl_queue = q
160
+ hay = x1 ,
161
+ needles = x2 ,
162
+ positions = dst ,
163
+ sycl_queue = q ,
164
+ depends = [
165
+ ev ,
166
+ ],
109
167
)
110
- ht_ev .wait ()
168
+
169
+ host_evs .append (ht_ev )
170
+ dpctl .SyclEvent .wait_for (host_evs )
111
171
112
172
return dst
0 commit comments