@@ -340,22 +340,29 @@ py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
340
340
int simplified_nd = needles_nd;
341
341
342
342
using shT = std::vector<py::ssize_t >;
343
-
344
343
shT simplified_common_shape;
345
344
shT simplified_needles_strides;
346
345
shT simplified_positions_strides;
347
346
py::ssize_t needles_offset (0 );
348
347
py::ssize_t positions_offset (0 );
349
348
350
- dpctl::tensor::py_internal::simplify_iteration_space (
351
- // modified by refernce
352
- simplified_nd,
353
- // read-only inputs
354
- needles_shape_ptr, needles_strides, positions_strides,
355
- // output, modified by reference
356
- simplified_common_shape, simplified_needles_strides,
357
- simplified_positions_strides, needles_offset, positions_offset);
358
-
349
+ if (simplified_nd == 0 ) {
350
+ // needles and positions have same nd
351
+ simplified_nd = 1 ;
352
+ simplified_common_shape.push_back (1 );
353
+ simplified_needles_strides.push_back (0 );
354
+ simplified_positions_strides.push_back (0 );
355
+ }
356
+ else {
357
+ dpctl::tensor::py_internal::simplify_iteration_space (
358
+ // modified by refernce
359
+ simplified_nd,
360
+ // read-only inputs
361
+ needles_shape_ptr, needles_strides, positions_strides,
362
+ // output, modified by reference
363
+ simplified_common_shape, simplified_needles_strides,
364
+ simplified_positions_strides, needles_offset, positions_offset);
365
+ }
359
366
std::vector<sycl::event> host_task_events;
360
367
host_task_events.reserve (2 );
361
368
0 commit comments