Skip to content

Commit cd5e7f7

Browse files
Merge pull request #1694 from IntelPython/backport-fix-for-searchsorted
Backport gh-1693 to maintenance/0.17.x branch
2 parents 862c133 + b93bf80 commit cd5e7f7

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ and complies with revision [2023.12](https://data-apis.org/array-api/2023.12/) o
3737
* Fixed bug in basic slicing of empty arrays: [gh-1680](https://github.com/IntelPython/dpctl/pull/1680)
3838
* Fixed bug in `tensor.bitwise_invert` for boolean input array: [gh-1681](https://github.com/IntelPython/dpctl/pull/1681)
3939
* Fixed bug in `tensor.repeat` on zero-size input arrays: [gh-1682](https://github.com/IntelPython/dpctl/pull/1682)
40+
* Fixed bug in `tensor.searchsorted` for 0d needle vector and strided hay: [gh-1694](https://github.com/IntelPython/dpctl/pull/1694)
4041

4142

4243
## [0.16.1] - Apr. 10, 2024

dpctl/tensor/libtensor/source/sorting/searchsorted.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,22 +340,29 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
340340
int simplified_nd = needles_nd;
341341

342342
using shT = std::vector<py::ssize_t>;
343-
344343
shT simplified_common_shape;
345344
shT simplified_needles_strides;
346345
shT simplified_positions_strides;
347346
py::ssize_t needles_offset(0);
348347
py::ssize_t positions_offset(0);
349348

350-
dpctl::tensor::py_internal::simplify_iteration_space(
351-
// modified by refernce
352-
simplified_nd,
353-
// read-only inputs
354-
needles_shape_ptr, needles_strides, positions_strides,
355-
// output, modified by reference
356-
simplified_common_shape, simplified_needles_strides,
357-
simplified_positions_strides, needles_offset, positions_offset);
358-
349+
if (simplified_nd == 0) {
350+
// needles and positions have same nd
351+
simplified_nd = 1;
352+
simplified_common_shape.push_back(1);
353+
simplified_needles_strides.push_back(0);
354+
simplified_positions_strides.push_back(0);
355+
}
356+
else {
357+
dpctl::tensor::py_internal::simplify_iteration_space(
358+
// modified by refernce
359+
simplified_nd,
360+
// read-only inputs
361+
needles_shape_ptr, needles_strides, positions_strides,
362+
// output, modified by reference
363+
simplified_common_shape, simplified_needles_strides,
364+
simplified_positions_strides, needles_offset, positions_offset);
365+
}
359366
std::vector<sycl::event> host_task_events;
360367
host_task_events.reserve(2);
361368

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,19 @@ def test_out_of_bound_sorter_values():
355355
p = dpt.searchsorted(x, x2, sorter=sorter)
356356
# verify that they were applied with mode="wrap"
357357
assert dpt.all(p == dpt.arange(3, dtype=p.dtype))
358+
359+
360+
def test_searchsorted_strided_scalar_needle():
361+
get_queue_or_skip()
362+
363+
a_max = 255
364+
365+
hay_stack = dpt.flip(
366+
dpt.repeat(dpt.arange(a_max - 1, -1, -1, dtype=dpt.int32), 4)
367+
)
368+
needles_np = np.squeeze(
369+
np.random.randint(0, a_max, dtype=dpt.int32, size=1), axis=0
370+
)
371+
needles = dpt.asarray(needles_np)
372+
373+
_check(hay_stack, needles, needles_np)

0 commit comments

Comments
 (0)