Skip to content

Commit 97b1c82

Browse files
committed
Adds code to handle edge case of strided input and scalar needle in searchsorted.cpp
1 parent 9df040a commit 97b1c82

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

dpctl/tensor/libtensor/source/sorting/searchsorted.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,22 +340,29 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
340340
int simplified_nd = needles_nd;
341341

342342
using shT = std::vector<py::ssize_t>;
343-
344343
shT simplified_common_shape;
345344
shT simplified_needles_strides;
346345
shT simplified_positions_strides;
347346
py::ssize_t needles_offset(0);
348347
py::ssize_t positions_offset(0);
349348

350-
dpctl::tensor::py_internal::simplify_iteration_space(
351-
// modified by refernce
352-
simplified_nd,
353-
// read-only inputs
354-
needles_shape_ptr, needles_strides, positions_strides,
355-
// output, modified by reference
356-
simplified_common_shape, simplified_needles_strides,
357-
simplified_positions_strides, needles_offset, positions_offset);
358-
349+
if (simplified_nd == 0) {
350+
// needles and positions have same nd
351+
simplified_nd = 1;
352+
simplified_common_shape.push_back(1);
353+
simplified_needles_strides.push_back(0);
354+
simplified_positions_strides.push_back(0);
355+
}
356+
else {
357+
dpctl::tensor::py_internal::simplify_iteration_space(
358+
// modified by refernce
359+
simplified_nd,
360+
// read-only inputs
361+
needles_shape_ptr, needles_strides, positions_strides,
362+
// output, modified by reference
363+
simplified_common_shape, simplified_needles_strides,
364+
simplified_positions_strides, needles_offset, positions_offset);
365+
}
359366
std::vector<sycl::event> host_task_events;
360367
host_task_events.reserve(2);
361368

0 commit comments

Comments
 (0)