Refactored boolean advanced indexing host tasks #1207

Merged: 3 commits, May 14, 2023
Changes from 1 commit
dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp: 107 changes (43 additions, 64 deletions)
@@ -473,36 +473,27 @@ py_extract(dpctl::tensor::usm_ndarray src,
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
                 exec_q, host_task_events, simplified_ortho_shape,
-                simplified_ortho_src_strides, simplified_ortho_dst_strides);
-        py::ssize_t *packed_ortho_src_dst_shape_strides =
-            std::get<0>(ptr_size_event_tuple1);
-        if (packed_ortho_src_dst_shape_strides == nullptr) {
+                simplified_ortho_src_strides, simplified_ortho_dst_strides,
+                masked_src_shape, masked_src_strides);
+        py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+        if (packed_shapes_strides == nullptr) {
             throw std::runtime_error("Unable to allocate device memory");
         }
-        sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
 
-        const auto &ptr_size_event_tuple2 =
-            device_allocate_and_pack<py::ssize_t>(
-                exec_q, host_task_events, masked_src_shape, masked_src_strides);
+        py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
         py::ssize_t *packed_masked_src_shape_strides =
-            std::get<0>(ptr_size_event_tuple2);
-        if (packed_masked_src_shape_strides == nullptr) {
-            copy_shape_strides_ev1.wait();
-            sycl::free(packed_ortho_src_dst_shape_strides, exec_q);
-            throw std::runtime_error("Unable to allocate device memory");
-        }
-        sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
+            packed_shapes_strides + (3 * ortho_nd);
 
         assert(masked_dst_shape.size() == 1);
         assert(masked_dst_strides.size() == 1);
 
         std::vector<sycl::event> all_deps;
-        all_deps.reserve(depends.size() + 2);
+        all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
-        all_deps.push_back(copy_shape_strides_ev1);
-        all_deps.push_back(copy_shape_strides_ev2);
+        all_deps.push_back(copy_shapes_strides_ev);
 
-        assert(all_deps.size() == depends.size() + 2);
+        assert(all_deps.size() == depends.size() + 1);
 
         // OrthogIndexerT orthog_src_dst_indexer_, MaskedIndexerT
         // masked_src_indexer_, MaskedIndexerT masked_dst_indexer_
@@ -520,10 +511,8 @@ py_extract(dpctl::tensor::usm_ndarray src,
             exec_q.submit([&](sycl::handler &cgh) {
                 cgh.depends_on(extract_ev);
                 auto ctx = exec_q.get_context();
-                cgh.host_task([ctx, packed_ortho_src_dst_shape_strides,
-                               packed_masked_src_shape_strides] {
-                    sycl::free(packed_ortho_src_dst_shape_strides, ctx);
-                    sycl::free(packed_masked_src_shape_strides, ctx);
+                cgh.host_task([ctx, packed_shapes_strides] {
+                    sycl::free(packed_shapes_strides, ctx);
                 });
             });
         host_task_events.push_back(cleanup_tmp_allocations_ev);
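The py_extract change above folds two device_allocate_and_pack calls, two copy events, and two sycl::free calls in the cleanup host task into a single packed allocation that the orthogonal and masked indexers address by offset. The following is a minimal, self-contained sketch of that pattern, not dpctl code: it assumes a plain SYCL queue and illustrative shape/stride values, and it uses sycl::malloc_device and queue::copy directly in place of the device_allocate_and_pack helper.

// Sketch only: one USM allocation holds the orthogonal shape/strides block
// followed by the masked shape/strides block, mirroring the layout implied by
// "packed_shapes_strides + (3 * ortho_nd)" in the diff. Sizes are illustrative.
#include <sycl/sycl.hpp>

#include <cstdint>
#include <stdexcept>
#include <vector>

int main()
{
    sycl::queue q;
    const std::size_t ortho_nd = 2;

    // Host staging buffer: ortho shape, ortho src strides, ortho dst strides
    // (3 * ortho_nd entries), then the masked shape and strides (2 entries).
    std::vector<std::int64_t> host_packed = {4, 5, 5, 1, 5, 1, 20, 1};

    // Single device allocation instead of two separate ones.
    std::int64_t *packed_shapes_strides =
        sycl::malloc_device<std::int64_t>(host_packed.size(), q);
    if (packed_shapes_strides == nullptr) {
        throw std::runtime_error("Unable to allocate device memory");
    }
    sycl::event copy_ev =
        q.copy(host_packed.data(), packed_shapes_strides, host_packed.size());

    // Sub-views into the same allocation, as in the refactored code.
    std::int64_t *ortho_shape_strides = packed_shapes_strides;
    std::int64_t *masked_shape_strides = packed_shapes_strides + 3 * ortho_nd;

    // Stand-in for the extract/place kernel: it depends on the copy event and
    // would read both sub-views.
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_ev);
        cgh.single_task([=]() {
            (void)ortho_shape_strides;
            (void)masked_shape_strides;
        });
    });

    // One host task frees the single allocation once the kernel is done,
    // instead of freeing two pointers.
    q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(comp_ev);
         auto ctx = q.get_context();
         cgh.host_task([ctx, packed_shapes_strides]() {
             sycl::free(packed_shapes_strides, ctx);
         });
     }).wait();

    return 0;
}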
@@ -684,7 +673,7 @@ py_place(dpctl::tensor::usm_ndarray dst,
     auto rhs_shape_vec = rhs.get_shape_vector();
     auto rhs_strides_vec = rhs.get_strides_vector();
 
-    sycl::event extract_ev;
+    sycl::event place_ev;
     std::vector<sycl::event> host_task_events{};
     if (axis_start == 0 && axis_end == dst_nd) {
         // empty orthogonal directions
@@ -713,13 +702,13 @@ py_place(dpctl::tensor::usm_ndarray dst,
 
         assert(all_deps.size() == depends.size() + 1);
 
-        extract_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p,
-                        rhs_data_p, dst_nd, packed_dst_shape_strides,
-                        rhs_shape_vec[0], rhs_strides_vec[0], all_deps);
+        place_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p, rhs_data_p,
+                      dst_nd, packed_dst_shape_strides, rhs_shape_vec[0],
+                      rhs_strides_vec[0], all_deps);
 
         sycl::event cleanup_tmp_allocations_ev =
             exec_q.submit([&](sycl::handler &cgh) {
-                cgh.depends_on(extract_ev);
+                cgh.depends_on(place_ev);
                 auto ctx = exec_q.get_context();
                 cgh.host_task([ctx, packed_dst_shape_strides] {
                     sycl::free(packed_dst_shape_strides, ctx);
@@ -778,65 +767,55 @@ py_place(dpctl::tensor::usm_ndarray dst,
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
                 exec_q, host_task_events, simplified_ortho_shape,
-                simplified_ortho_dst_strides, simplified_ortho_rhs_strides);
-        py::ssize_t *packed_ortho_dst_rhs_shape_strides =
-            std::get<0>(ptr_size_event_tuple1);
-        if (packed_ortho_dst_rhs_shape_strides == nullptr) {
+                simplified_ortho_dst_strides, simplified_ortho_rhs_strides,
+                masked_dst_shape, masked_dst_strides);
+        py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+        if (packed_shapes_strides == nullptr) {
             throw std::runtime_error("Unable to allocate device memory");
         }
-        sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
 
-        auto ptr_size_event_tuple2 = device_allocate_and_pack<py::ssize_t>(
-            exec_q, host_task_events, masked_dst_shape, masked_dst_strides);
+        py::ssize_t *packed_ortho_dst_rhs_shape_strides = packed_shapes_strides;
         py::ssize_t *packed_masked_dst_shape_strides =
-            std::get<0>(ptr_size_event_tuple2);
-        if (packed_masked_dst_shape_strides == nullptr) {
-            copy_shape_strides_ev1.wait();
-            sycl::free(packed_ortho_dst_rhs_shape_strides, exec_q);
-            throw std::runtime_error("Unable to allocate device memory");
-        }
-        sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
+            packed_shapes_strides + (3 * ortho_nd);
 
         assert(masked_rhs_shape.size() == 1);
         assert(masked_rhs_strides.size() == 1);
 
         std::vector<sycl::event> all_deps;
-        all_deps.reserve(depends.size() + 2);
+        all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
-        all_deps.push_back(copy_shape_strides_ev1);
-        all_deps.push_back(copy_shape_strides_ev2);
-
-        assert(all_deps.size() == depends.size() + 2);
-
-        extract_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
-                        cumsum_data_p, rhs_data_p,
-                        // data to build orthog_dst_rhs_indexer
-                        ortho_nd, packed_ortho_dst_rhs_shape_strides,
-                        ortho_dst_offset, ortho_rhs_offset,
-                        // data to build masked_dst_indexer
-                        masked_dst_nd, packed_masked_dst_shape_strides,
-                        // data to build masked_dst_indexer,
-                        masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
+        all_deps.push_back(copy_shapes_strides_ev);
+
+        assert(all_deps.size() == depends.size() + 1);
+
+        place_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
+                      cumsum_data_p, rhs_data_p,
+                      // data to build orthog_dst_rhs_indexer
+                      ortho_nd, packed_ortho_dst_rhs_shape_strides,
+                      ortho_dst_offset, ortho_rhs_offset,
+                      // data to build masked_dst_indexer
+                      masked_dst_nd, packed_masked_dst_shape_strides,
+                      // data to build masked_dst_indexer,
+                      masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
 
         sycl::event cleanup_tmp_allocations_ev =
             exec_q.submit([&](sycl::handler &cgh) {
-                cgh.depends_on(extract_ev);
+                cgh.depends_on(place_ev);
                 auto ctx = exec_q.get_context();
-                cgh.host_task([ctx, packed_ortho_dst_rhs_shape_strides,
-                               packed_masked_dst_shape_strides] {
-                    sycl::free(packed_ortho_dst_rhs_shape_strides, ctx);
-                    sycl::free(packed_masked_dst_shape_strides, ctx);
+                cgh.host_task([ctx, packed_shapes_strides] {
+                    sycl::free(packed_shapes_strides, ctx);
                 });
             });
         host_task_events.push_back(cleanup_tmp_allocations_ev);
     }
 
-    host_task_events.push_back(extract_ev);
+    host_task_events.push_back(place_ev);
 
     sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive(
         exec_q, {dst, cumsum, rhs}, host_task_events);
 
-    return std::make_pair(py_obj_management_host_task_ev, extract_ev);
+    return std::make_pair(py_obj_management_host_task_ev, place_ev);
 }
 
 // Non-zero
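As the last hunk shows, py_place (like py_extract) returns a pair of events: the host-task event produced by dpctl::utils::keep_args_alive, which holds the Python arguments until the submitted work finishes, and the computation event itself. A hypothetical consumer-side sketch follows; it is not part of this PR, and the function name and synchronization policy are illustrative only.

// Hypothetical consumer of the {keep-args-alive event, computation event} pair
// returned by functions such as py_place above. Names are illustrative.
#include <sycl/sycl.hpp>

#include <utility>
#include <vector>

void consume_event_pair(sycl::queue exec_q,
                        std::pair<sycl::event, sycl::event> res)
{
    sycl::event keep_args_alive_ev = res.first; // host task holding Python args
    sycl::event computation_ev = res.second;    // e.g. place_ev above

    // Later device work only needs to depend on the computation event ...
    exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(computation_ev);
        cgh.single_task([]() { /* follow-up kernel */ });
    });

    // ... while a blocking caller waits on both before releasing its arguments.
    sycl::event::wait({keep_args_alive_ev, computation_ev});
}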