Skip to content

Commit 0b66dce

Browse files
committed
Implement dpnp.searchsorted
1 parent 1b58244 commit 0b66dce

File tree

13 files changed

+651
-394
lines changed

13 files changed

+651
-394
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,11 @@ enum class DPNPFuncName : size_t
339339
DPNP_FN_RNG_ZIPF_EXT, /**< Used in numpy.random.zipf() impl, requires extra
340340
parameters */
341341
DPNP_FN_SEARCHSORTED, /**< Used in numpy.searchsorted() impl */
342-
DPNP_FN_SEARCHSORTED_EXT, /**< Used in numpy.searchsorted() impl, requires
343-
extra parameters */
344-
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
345-
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
346-
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
347-
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
348-
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
342+
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
343+
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
344+
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
345+
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
346+
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
349347
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
350348
*/
351349
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -403,17 +403,6 @@ void (*dpnp_searchsorted_default_c)(void *,
403403
const size_t) =
404404
dpnp_searchsorted_c<_DataType, _IndexingType>;
405405

406-
template <typename _DataType, typename _IndexingType>
407-
DPCTLSyclEventRef (*dpnp_searchsorted_ext_c)(DPCTLSyclQueueRef,
408-
void *,
409-
const void *,
410-
const void *,
411-
bool,
412-
const size_t,
413-
const size_t,
414-
const DPCTLEventVectorRef) =
415-
dpnp_searchsorted_c<_DataType, _IndexingType>;
416-
417406
template <typename _DataType>
418407
class dpnp_sort_c_kernel;
419408

@@ -507,15 +496,6 @@ void func_map_init_sorting(func_map_t &fmap)
507496
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_DBL][eft_DBL] = {
508497
eft_DBL, (void *)dpnp_searchsorted_default_c<double, int64_t>};
509498

510-
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_INT][eft_INT] = {
511-
eft_INT, (void *)dpnp_searchsorted_ext_c<int32_t, int64_t>};
512-
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_LNG][eft_LNG] = {
513-
eft_LNG, (void *)dpnp_searchsorted_ext_c<int64_t, int64_t>};
514-
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_FLT][eft_FLT] = {
515-
eft_FLT, (void *)dpnp_searchsorted_ext_c<float, int64_t>};
516-
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_DBL][eft_DBL] = {
517-
eft_DBL, (void *)dpnp_searchsorted_ext_c<double, int64_t>};
518-
519499
fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = {
520500
eft_INT, (void *)dpnp_sort_default_c<int32_t>};
521501
fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
159159
DPNP_FN_RNG_WEIBULL_EXT
160160
DPNP_FN_RNG_ZIPF
161161
DPNP_FN_RNG_ZIPF_EXT
162-
DPNP_FN_SEARCHSORTED
163-
DPNP_FN_SEARCHSORTED_EXT
164162
DPNP_FN_TRACE
165163
DPNP_FN_TRACE_EXT
166164
DPNP_FN_TRANSPOSE

dpnp/dpnp_algo/dpnp_algo_sorting.pxi

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ and the rest of the library
3737

3838
__all__ += [
3939
"dpnp_partition",
40-
"dpnp_searchsorted",
4140
]
4241

4342

@@ -49,14 +48,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_partition_t)(c_dpctl.DPCTLSyclQueu
4948
const shape_elem_type * ,
5049
const size_t,
5150
const c_dpctl.DPCTLEventVectorRef)
52-
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_searchsorted_t)(c_dpctl.DPCTLSyclQueueRef,
53-
void * ,
54-
const void * ,
55-
const void * ,
56-
bool,
57-
const size_t,
58-
const size_t,
59-
const c_dpctl.DPCTLEventVectorRef)
6051

6152

6253
cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):
@@ -98,44 +89,3 @@ cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, a
9889
c_dpctl.DPCTLEvent_Delete(event_ref)
9990

10091
return result
101-
102-
103-
cpdef utils.dpnp_descriptor dpnp_searchsorted(utils.dpnp_descriptor arr, utils.dpnp_descriptor v, side='left'):
104-
if side is 'left':
105-
side_ = True
106-
else:
107-
side_ = False
108-
109-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
110-
111-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SEARCHSORTED_EXT, param1_type, param1_type)
112-
113-
arr_obj = arr.get_array()
114-
115-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(v.shape,
116-
dpnp.int64,
117-
None,
118-
device=arr_obj.sycl_device,
119-
usm_type=arr_obj.usm_type,
120-
sycl_queue=arr_obj.sycl_queue)
121-
122-
result_sycl_queue = result.get_array().sycl_queue
123-
124-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
125-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
126-
127-
cdef fptr_dpnp_searchsorted_t func = <fptr_dpnp_searchsorted_t > kernel_data.ptr
128-
129-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
130-
arr.get_data(),
131-
v.get_data(),
132-
result.get_data(),
133-
side_,
134-
arr.size,
135-
v.size,
136-
NULL) # dep_events_ref
137-
138-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
139-
c_dpctl.DPCTLEvent_Delete(event_ref)
140-
141-
return result

dpnp/dpnp_array.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,17 @@ def round(self, decimals=0, out=None):
11211121

11221122
return dpnp.around(self, decimals, out)
11231123

1124-
# 'searchsorted',
1124+
def searchsorted(self, v, side="left", sorter=None):
1125+
"""
1126+
Find indices where elements of `v` should be inserted in `a`
1127+
to maintain order.
1128+
1129+
Refer to :obj:`dpnp.searchsorted` for full documentation
1130+
1131+
"""
1132+
1133+
return dpnp.searchsorted(self, v, side=side, sorter=sorter)
1134+
11251135
# 'setfield',
11261136
# 'setflags',
11271137

dpnp/dpnp_iface_searching.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,57 @@ def searchsorted(a, v, side="left", sorter=None):
232232
233233
For full documentation refer to :obj:`numpy.searchsorted`.
234234
235+
Parameters
236+
----------
237+
a : {dpnp.ndarray, usm_ndarray}
238+
Input 1-D array. If `sorter` is ``None``, then it must be sorted in
239+
ascending order, otherwise `sorter` must be an array of indices that
240+
sort it.
241+
v : {dpnp.ndarray, usm_ndarray, scalar}
242+
Values to insert into `a`.
243+
side : {'left', 'right'}, optional
244+
If ``'left'``, the index of the first suitable location found is given.
245+
If ``'right'``, return the last such index. If there is no suitable
246+
index, return either 0 or N (where N is the length of `a`).
247+
sorter : {dpnp.ndarray, usm_ndarray}, optional
248+
Optional 1-D array of integer indices that sort array a into ascending
249+
order. They are typically the result of argsort.
250+
251+
Returns
252+
-------
253+
indices : dpnp.ndarray
254+
Array of insertion points with the same shape as `v`,
255+
or 0-D array if `v` is a scalar.
256+
257+
See Also
258+
--------
259+
:obj:`dpnp.sort` : Return a sorted copy of an array.
260+
:obj:`dpnp.histogram` : Produce histogram from 1-D data.
261+
262+
Examples
263+
--------
264+
>>> import dpnp as np
265+
>>> a = np.array([11,12,13,14,15])
266+
>>> np.searchsorted(a, 13)
267+
array(2)
268+
>>> np.searchsorted(a, 13, side='right')
269+
array(3)
270+
>>> v = np.array([-10, 20, 12, 13])
271+
>>> np.searchsorted(a, v)
272+
array([0, 5, 1, 2])
273+
235274
"""
236275

237-
return call_origin(numpy.where, a, v, side, sorter)
276+
usm_a = dpnp.get_usm_ndarray(a)
277+
if dpnp.isscalar(v):
278+
usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
279+
else:
280+
usm_v = dpnp.get_usm_ndarray(v)
281+
282+
usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
283+
return dpnp_array._create_from_usm_ndarray(
284+
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
285+
)
238286

239287

240288
def where(condition, x=None, y=None, /):

dpnp/dpnp_iface_sorting.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,13 @@
4747
# pylint: disable=no-name-in-module
4848
from .dpnp_algo import (
4949
dpnp_partition,
50-
dpnp_searchsorted,
5150
)
5251
from .dpnp_array import dpnp_array
5352
from .dpnp_utils import (
5453
call_origin,
5554
)
5655

57-
__all__ = ["argsort", "partition", "searchsorted", "sort"]
56+
__all__ = ["argsort", "partition", "sort"]
5857

5958

6059
def argsort(a, axis=-1, kind=None, order=None):
@@ -189,41 +188,6 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
189188
return call_origin(numpy.partition, x1, kth, axis, kind, order)
190189

191190

192-
def searchsorted(x1, x2, side="left", sorter=None):
193-
"""
194-
Find indices where elements should be inserted to maintain order.
195-
196-
For full documentation refer to :obj:`numpy.searchsorted`.
197-
198-
Limitations
199-
-----------
200-
Input arrays is supported as :obj:`dpnp.ndarray`.
201-
Input array is supported only sorted.
202-
Input side is supported only values ``left``, ``right``.
203-
Parameter `sorter` is supported only with default values.
204-
205-
"""
206-
207-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
208-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
209-
# pylint: disable=condition-evals-to-constant
210-
if 0 and x1_desc and x2_desc:
211-
if x1_desc.ndim != 1:
212-
pass
213-
elif x1_desc.dtype != x2_desc.dtype:
214-
pass
215-
elif side not in ["left", "right"]:
216-
pass
217-
elif sorter is not None:
218-
pass
219-
elif x1_desc.size < 2:
220-
pass
221-
else:
222-
return dpnp_searchsorted(x1_desc, x2_desc, side=side).get_pyobj()
223-
224-
return call_origin(numpy.searchsorted, x1, x2, side=side, sorter=sorter)
225-
226-
227191
def sort(a, axis=-1, kind=None, order=None):
228192
"""
229193
Return a sorted copy of an array.

tests/skipped_tests.tbl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -705,28 +705,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
705705
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
706706
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2
707707

708-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}]
709-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}]
710-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}]
711-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}]
712-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}]
713-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}]
714-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}]
715-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}]
716-
717-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0]
718-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1]
719-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2]
720-
721-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0]
722-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1]
723-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2]
724-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3]
725-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4]
726-
727-
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero
728-
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero
729-
730708
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
731709
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1
732710
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -767,28 +767,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
767767
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
768768
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2
769769

770-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}]
771-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}]
772-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}]
773-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}]
774-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}]
775-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}]
776-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}]
777-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}]
778-
779-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0]
780-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1]
781-
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2]
782-
783-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0]
784-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1]
785-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2]
786-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3]
787-
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4]
788-
789-
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero
790-
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero
791-
792770
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
793771
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1
794772
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2

0 commit comments

Comments
 (0)