Skip to content

Commit 8ed4fdf

Browse files
committed
Corrected boolean indexing cumsum
- The cumulative sum was being calculated incorrectly -- the offset from stride simplification was unused and the result was incorrect for some cases with negative strides
1 parent cf4660d commit 8ed4fdf

File tree

2 files changed

+8
-29
lines changed

2 files changed

+8
-29
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ typedef size_t (*mask_positions_strided_impl_fn_ptr_t)(
424424
size_t,
425425
const char *,
426426
int,
427-
py::ssize_t,
428427
const py::ssize_t *,
429428
char *,
430429
std::vector<sycl::event> const &);
@@ -434,7 +433,6 @@ size_t mask_positions_strided_impl(sycl::queue q,
434433
size_t n_elems,
435434
const char *mask,
436435
int nd,
437-
py::ssize_t input_offset,
438436
const py::ssize_t *shape_strides,
439437
char *cumsum,
440438
std::vector<sycl::event> const &depends = {})
@@ -444,7 +442,7 @@ size_t mask_positions_strided_impl(sycl::queue q,
444442
cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
445443
size_t wg_size = 128;
446444

447-
StridedIndexer strided_indexer{nd, input_offset, shape_strides};
445+
StridedIndexer strided_indexer{nd, 0, shape_strides};
448446
NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
449447

450448
sycl::event comp_ev =

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -201,46 +201,27 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
201201
return fn(exec_q, mask_size, mask_data, cumsum_data, depends);
202202
}
203203

204-
const py::ssize_t *shape = mask.get_shape_raw();
205-
auto const &strides_vector = mask.get_strides_vector();
206-
207-
using shT = std::vector<py::ssize_t>;
208-
shT simplified_shape;
209-
shT simplified_strides;
210-
py::ssize_t offset(0);
211-
204+
// Strided implementation
212205
int mask_nd = mask.get_ndim();
213-
int nd = mask_nd;
214-
215-
dpctl::tensor::py_internal::simplify_iteration_space_1(
216-
nd, shape, strides_vector, simplified_shape, simplified_strides,
217-
offset);
218-
219-
if (nd == 1 && simplified_strides[0] == 1) {
220-
auto fn = (use_i32)
221-
? mask_positions_contig_i32_dispatch_vector[mask_typeid]
222-
: mask_positions_contig_i64_dispatch_vector[mask_typeid];
223-
224-
return fn(exec_q, mask_size, mask_data, cumsum_data, depends);
225-
}
206+
auto const &shape_vector = mask.get_shape_vector();
207+
auto const &strides_vector = mask.get_strides_vector();
226208

227-
// Strided implementation
228209
auto strided_fn =
229210
(use_i32) ? mask_positions_strided_i32_dispatch_vector[mask_typeid]
230211
: mask_positions_strided_i64_dispatch_vector[mask_typeid];
231-
std::vector<sycl::event> host_task_events;
232212

213+
std::vector<sycl::event> host_task_events;
233214
using dpctl::tensor::offset_utils::device_allocate_and_pack;
234215
const auto &ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
235-
exec_q, host_task_events, simplified_shape, simplified_strides);
216+
exec_q, host_task_events, shape_vector, strides_vector);
236217
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
237218
if (shape_strides == nullptr) {
238219
sycl::event::wait(host_task_events);
239220
throw std::runtime_error("Unexpected error");
240221
}
241222
sycl::event copy_shape_ev = std::get<2>(ptr_size_event_tuple);
242223

243-
if (2 * static_cast<size_t>(nd) != std::get<1>(ptr_size_event_tuple)) {
224+
if (2 * static_cast<size_t>(mask_nd) != std::get<1>(ptr_size_event_tuple)) {
244225
copy_shape_ev.wait();
245226
sycl::event::wait(host_task_events);
246227
sycl::free(shape_strides, exec_q);
@@ -253,7 +234,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
253234
dependent_events.insert(dependent_events.end(), depends.begin(),
254235
depends.end());
255236

256-
size_t total_set = strided_fn(exec_q, mask_size, mask_data, nd, offset,
237+
size_t total_set = strided_fn(exec_q, mask_size, mask_data, mask_nd,
257238
shape_strides, cumsum_data, dependent_events);
258239

259240
sycl::event::wait(host_task_events);

0 commit comments

Comments
 (0)