Skip to content

Commit 4f16b93

Browse files
committed
Refactored host tasks used in boolean indexing
- Using a single host_task has shown to improve stability on CPU
1 parent 1320d39 commit 4f16b93

File tree

1 file changed

+43
-64
lines changed

1 file changed

+43
-64
lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 43 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -473,36 +473,27 @@ py_extract(dpctl::tensor::usm_ndarray src,
473473
const auto &ptr_size_event_tuple1 =
474474
device_allocate_and_pack<py::ssize_t>(
475475
exec_q, host_task_events, simplified_ortho_shape,
476-
simplified_ortho_src_strides, simplified_ortho_dst_strides);
477-
py::ssize_t *packed_ortho_src_dst_shape_strides =
478-
std::get<0>(ptr_size_event_tuple1);
479-
if (packed_ortho_src_dst_shape_strides == nullptr) {
476+
simplified_ortho_src_strides, simplified_ortho_dst_strides,
477+
masked_src_shape, masked_src_strides);
478+
py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
479+
if (packed_shapes_strides == nullptr) {
480480
throw std::runtime_error("Unable to allocate device memory");
481481
}
482-
sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
482+
sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
483483

484-
const auto &ptr_size_event_tuple2 =
485-
device_allocate_and_pack<py::ssize_t>(
486-
exec_q, host_task_events, masked_src_shape, masked_src_strides);
484+
py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
487485
py::ssize_t *packed_masked_src_shape_strides =
488-
std::get<0>(ptr_size_event_tuple2);
489-
if (packed_masked_src_shape_strides == nullptr) {
490-
copy_shape_strides_ev1.wait();
491-
sycl::free(packed_ortho_src_dst_shape_strides, exec_q);
492-
throw std::runtime_error("Unable to allocate device memory");
493-
}
494-
sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
486+
packed_shapes_strides + (3 * ortho_nd);
495487

496488
assert(masked_dst_shape.size() == 1);
497489
assert(masked_dst_strides.size() == 1);
498490

499491
std::vector<sycl::event> all_deps;
500-
all_deps.reserve(depends.size() + 2);
492+
all_deps.reserve(depends.size() + 1);
501493
all_deps.insert(all_deps.end(), depends.begin(), depends.end());
502-
all_deps.push_back(copy_shape_strides_ev1);
503-
all_deps.push_back(copy_shape_strides_ev2);
494+
all_deps.push_back(copy_shapes_strides_ev);
504495

505-
assert(all_deps.size() == depends.size() + 2);
496+
assert(all_deps.size() == depends.size() + 1);
506497

507498
// OrthogIndexerT orthog_src_dst_indexer_, MaskedIndexerT
508499
// masked_src_indexer_, MaskedIndexerT masked_dst_indexer_
@@ -520,10 +511,8 @@ py_extract(dpctl::tensor::usm_ndarray src,
520511
exec_q.submit([&](sycl::handler &cgh) {
521512
cgh.depends_on(extract_ev);
522513
auto ctx = exec_q.get_context();
523-
cgh.host_task([ctx, packed_ortho_src_dst_shape_strides,
524-
packed_masked_src_shape_strides] {
525-
sycl::free(packed_ortho_src_dst_shape_strides, ctx);
526-
sycl::free(packed_masked_src_shape_strides, ctx);
514+
cgh.host_task([ctx, packed_shapes_strides] {
515+
sycl::free(packed_shapes_strides, ctx);
527516
});
528517
});
529518
host_task_events.push_back(cleanup_tmp_allocations_ev);
@@ -684,7 +673,7 @@ py_place(dpctl::tensor::usm_ndarray dst,
684673
auto rhs_shape_vec = rhs.get_shape_vector();
685674
auto rhs_strides_vec = rhs.get_strides_vector();
686675

687-
sycl::event extract_ev;
676+
sycl::event place_ev;
688677
std::vector<sycl::event> host_task_events{};
689678
if (axis_start == 0 && axis_end == dst_nd) {
690679
// empty orthogonal directions
@@ -713,13 +702,13 @@ py_place(dpctl::tensor::usm_ndarray dst,
713702

714703
assert(all_deps.size() == depends.size() + 1);
715704

716-
extract_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p,
717-
rhs_data_p, dst_nd, packed_dst_shape_strides,
718-
rhs_shape_vec[0], rhs_strides_vec[0], all_deps);
705+
place_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p, rhs_data_p,
706+
dst_nd, packed_dst_shape_strides, rhs_shape_vec[0],
707+
rhs_strides_vec[0], all_deps);
719708

720709
sycl::event cleanup_tmp_allocations_ev =
721710
exec_q.submit([&](sycl::handler &cgh) {
722-
cgh.depends_on(extract_ev);
711+
cgh.depends_on(place_ev);
723712
auto ctx = exec_q.get_context();
724713
cgh.host_task([ctx, packed_dst_shape_strides] {
725714
sycl::free(packed_dst_shape_strides, ctx);
@@ -778,65 +767,55 @@ py_place(dpctl::tensor::usm_ndarray dst,
778767
const auto &ptr_size_event_tuple1 =
779768
device_allocate_and_pack<py::ssize_t>(
780769
exec_q, host_task_events, simplified_ortho_shape,
781-
simplified_ortho_dst_strides, simplified_ortho_rhs_strides);
782-
py::ssize_t *packed_ortho_dst_rhs_shape_strides =
783-
std::get<0>(ptr_size_event_tuple1);
784-
if (packed_ortho_dst_rhs_shape_strides == nullptr) {
770+
simplified_ortho_dst_strides, simplified_ortho_rhs_strides,
771+
masked_dst_shape, masked_dst_strides);
772+
py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
773+
if (packed_shapes_strides == nullptr) {
785774
throw std::runtime_error("Unable to allocate device memory");
786775
}
787-
sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
776+
sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
788777

789-
auto ptr_size_event_tuple2 = device_allocate_and_pack<py::ssize_t>(
790-
exec_q, host_task_events, masked_dst_shape, masked_dst_strides);
778+
py::ssize_t *packed_ortho_dst_rhs_shape_strides = packed_shapes_strides;
791779
py::ssize_t *packed_masked_dst_shape_strides =
792-
std::get<0>(ptr_size_event_tuple2);
793-
if (packed_masked_dst_shape_strides == nullptr) {
794-
copy_shape_strides_ev1.wait();
795-
sycl::free(packed_ortho_dst_rhs_shape_strides, exec_q);
796-
throw std::runtime_error("Unable to allocate device memory");
797-
}
798-
sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
780+
packed_shapes_strides + (3 * ortho_nd);
799781

800782
assert(masked_rhs_shape.size() == 1);
801783
assert(masked_rhs_strides.size() == 1);
802784

803785
std::vector<sycl::event> all_deps;
804-
all_deps.reserve(depends.size() + 2);
786+
all_deps.reserve(depends.size() + 1);
805787
all_deps.insert(all_deps.end(), depends.begin(), depends.end());
806-
all_deps.push_back(copy_shape_strides_ev1);
807-
all_deps.push_back(copy_shape_strides_ev2);
808-
809-
assert(all_deps.size() == depends.size() + 2);
810-
811-
extract_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
812-
cumsum_data_p, rhs_data_p,
813-
// data to build orthog_dst_rhs_indexer
814-
ortho_nd, packed_ortho_dst_rhs_shape_strides,
815-
ortho_dst_offset, ortho_rhs_offset,
816-
// data to build masked_dst_indexer
817-
masked_dst_nd, packed_masked_dst_shape_strides,
818-
// data to build masked_dst_indexer,
819-
masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
788+
all_deps.push_back(copy_shapes_strides_ev);
789+
790+
assert(all_deps.size() == depends.size() + 1);
791+
792+
place_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
793+
cumsum_data_p, rhs_data_p,
794+
// data to build orthog_dst_rhs_indexer
795+
ortho_nd, packed_ortho_dst_rhs_shape_strides,
796+
ortho_dst_offset, ortho_rhs_offset,
797+
// data to build masked_dst_indexer
798+
masked_dst_nd, packed_masked_dst_shape_strides,
799+
// data to build masked_dst_indexer,
800+
masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
820801

821802
sycl::event cleanup_tmp_allocations_ev =
822803
exec_q.submit([&](sycl::handler &cgh) {
823-
cgh.depends_on(extract_ev);
804+
cgh.depends_on(place_ev);
824805
auto ctx = exec_q.get_context();
825-
cgh.host_task([ctx, packed_ortho_dst_rhs_shape_strides,
826-
packed_masked_dst_shape_strides] {
827-
sycl::free(packed_ortho_dst_rhs_shape_strides, ctx);
828-
sycl::free(packed_masked_dst_shape_strides, ctx);
806+
cgh.host_task([ctx, packed_shapes_strides] {
807+
sycl::free(packed_shapes_strides, ctx);
829808
});
830809
});
831810
host_task_events.push_back(cleanup_tmp_allocations_ev);
832811
}
833812

834-
host_task_events.push_back(extract_ev);
813+
host_task_events.push_back(place_ev);
835814

836815
sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive(
837816
exec_q, {dst, cumsum, rhs}, host_task_events);
838817

839-
return std::make_pair(py_obj_management_host_task_ev, extract_ev);
818+
return std::make_pair(py_obj_management_host_task_ev, place_ev);
840819
}
841820

842821
// Non-zero

0 commit comments

Comments
 (0)