Skip to content

Commit 65baa6d

Browse files
authored
Fix tensor.searchsorted for x1 with strides and scalar (0D) x2 (#1693)
* Adds code to handle edge case of strided input and scalar `needle` in `searchsorted.cpp` * Adds a test for fix to gh-1689
1 parent 4dac76c commit 65baa6d

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

dpctl/tensor/libtensor/source/sorting/searchsorted.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,22 +340,29 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
340340
int simplified_nd = needles_nd;
341341

342342
using shT = std::vector<py::ssize_t>;
343-
344343
shT simplified_common_shape;
345344
shT simplified_needles_strides;
346345
shT simplified_positions_strides;
347346
py::ssize_t needles_offset(0);
348347
py::ssize_t positions_offset(0);
349348

350-
dpctl::tensor::py_internal::simplify_iteration_space(
351-
// modified by refernce
352-
simplified_nd,
353-
// read-only inputs
354-
needles_shape_ptr, needles_strides, positions_strides,
355-
// output, modified by reference
356-
simplified_common_shape, simplified_needles_strides,
357-
simplified_positions_strides, needles_offset, positions_offset);
358-
349+
if (simplified_nd == 0) {
350+
// needles and positions have same nd
351+
simplified_nd = 1;
352+
simplified_common_shape.push_back(1);
353+
simplified_needles_strides.push_back(0);
354+
simplified_positions_strides.push_back(0);
355+
}
356+
else {
357+
dpctl::tensor::py_internal::simplify_iteration_space(
358+
// modified by refernce
359+
simplified_nd,
360+
// read-only inputs
361+
needles_shape_ptr, needles_strides, positions_strides,
362+
// output, modified by reference
363+
simplified_common_shape, simplified_needles_strides,
364+
simplified_positions_strides, needles_offset, positions_offset);
365+
}
359366
std::vector<sycl::event> host_task_events;
360367
host_task_events.reserve(2);
361368

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,19 @@ def test_out_of_bound_sorter_values():
355355
p = dpt.searchsorted(x, x2, sorter=sorter)
356356
# verify that they were applied with mode="wrap"
357357
assert dpt.all(p == dpt.arange(3, dtype=p.dtype))
358+
359+
360+
def test_searchsorted_strided_scalar_needle():
361+
get_queue_or_skip()
362+
363+
a_max = 255
364+
365+
hay_stack = dpt.flip(
366+
dpt.repeat(dpt.arange(a_max - 1, -1, -1, dtype=dpt.int32), 4)
367+
)
368+
needles_np = np.squeeze(
369+
np.random.randint(0, a_max, dtype=dpt.int32, size=1), axis=0
370+
)
371+
needles = dpt.asarray(needles_np)
372+
373+
_check(hay_stack, needles, needles_np)

0 commit comments

Comments
 (0)