@@ -201,46 +201,27 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
201
201
return fn (exec_q, mask_size, mask_data, cumsum_data, depends);
202
202
}
203
203
204
- const py::ssize_t *shape = mask.get_shape_raw ();
205
- auto const &strides_vector = mask.get_strides_vector ();
206
-
207
- using shT = std::vector<py::ssize_t >;
208
- shT simplified_shape;
209
- shT simplified_strides;
210
- py::ssize_t offset (0 );
211
-
204
+ // Strided implementation
212
205
int mask_nd = mask.get_ndim ();
213
- int nd = mask_nd;
214
-
215
- dpctl::tensor::py_internal::simplify_iteration_space_1 (
216
- nd, shape, strides_vector, simplified_shape, simplified_strides,
217
- offset);
218
-
219
- if (nd == 1 && simplified_strides[0 ] == 1 ) {
220
- auto fn = (use_i32)
221
- ? mask_positions_contig_i32_dispatch_vector[mask_typeid]
222
- : mask_positions_contig_i64_dispatch_vector[mask_typeid];
223
-
224
- return fn (exec_q, mask_size, mask_data, cumsum_data, depends);
225
- }
206
+ auto const &shape_vector = mask.get_shape_vector ();
207
+ auto const &strides_vector = mask.get_strides_vector ();
226
208
227
- // Strided implementation
228
209
auto strided_fn =
229
210
(use_i32) ? mask_positions_strided_i32_dispatch_vector[mask_typeid]
230
211
: mask_positions_strided_i64_dispatch_vector[mask_typeid];
231
- std::vector<sycl::event> host_task_events;
232
212
213
+ std::vector<sycl::event> host_task_events;
233
214
using dpctl::tensor::offset_utils::device_allocate_and_pack;
234
215
const auto &ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t >(
235
- exec_q, host_task_events, simplified_shape, simplified_strides );
216
+ exec_q, host_task_events, shape_vector, strides_vector );
236
217
py::ssize_t *shape_strides = std::get<0 >(ptr_size_event_tuple);
237
218
if (shape_strides == nullptr ) {
238
219
sycl::event::wait (host_task_events);
239
220
throw std::runtime_error (" Unexpected error" );
240
221
}
241
222
sycl::event copy_shape_ev = std::get<2 >(ptr_size_event_tuple);
242
223
243
- if (2 * static_cast <size_t >(nd ) != std::get<1 >(ptr_size_event_tuple)) {
224
+ if (2 * static_cast <size_t >(mask_nd ) != std::get<1 >(ptr_size_event_tuple)) {
244
225
copy_shape_ev.wait ();
245
226
sycl::event::wait (host_task_events);
246
227
sycl::free (shape_strides, exec_q);
@@ -253,7 +234,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
253
234
dependent_events.insert (dependent_events.end (), depends.begin (),
254
235
depends.end ());
255
236
256
- size_t total_set = strided_fn (exec_q, mask_size, mask_data, nd, offset ,
237
+ size_t total_set = strided_fn (exec_q, mask_size, mask_data, mask_nd ,
257
238
shape_strides, cumsum_data, dependent_events);
258
239
259
240
sycl::event::wait (host_task_events);
0 commit comments