Skip to content

Commit 0c6d3f8

Browse files
committed
Align with changes to device_allocate_and_pack in dpctl
1 parent 9b70c28 commit 0c6d3f8

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ using value_type_of_t = typename value_type_of<T>::type;
7979
typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &,
8080
int,
8181
size_t,
82-
py::ssize_t *,
82+
const py::ssize_t *,
8383
const py::object &,
8484
const py::object &,
8585
const py::object &,
@@ -93,7 +93,7 @@ template <typename T>
9393
sycl::event nan_to_num_call(sycl::queue &exec_q,
9494
int nd,
9595
size_t nelems,
96-
py::ssize_t *shape_strides,
96+
const py::ssize_t *shape_strides,
9797
const py::object &py_nan,
9898
const py::object &py_posinf,
9999
const py::object &py_neginf,
@@ -302,15 +302,12 @@ std::pair<sycl::event, sycl::event>
302302
std::vector<sycl::event> host_tasks{};
303303
host_tasks.reserve(2);
304304

305-
const auto &ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
305+
auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
306306
q, host_tasks, simplified_shape, simplified_src_strides,
307307
simplified_dst_strides);
308-
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_triple_);
308+
auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
309309
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
310-
311-
if (shape_strides == nullptr) {
312-
throw std::runtime_error("Device memory allocation failed");
313-
}
310+
const py::ssize_t *shape_strides = shape_strides_owner.get();
314311

315312
std::vector<sycl::event> all_deps;
316313
all_deps.reserve(depends.size() + 1);
@@ -322,13 +319,9 @@ std::pair<sycl::event, sycl::event>
322319
src_offset, dst_data, dst_offset, all_deps);
323320

324321
// async free of shape_strides temporary
325-
auto ctx = q.get_context();
326-
sycl::event tmp_cleanup_ev = q.submit([&](sycl::handler &cgh) {
327-
cgh.depends_on(comp_ev);
328-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
329-
cgh.host_task(
330-
[ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); });
331-
});
322+
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
323+
q, {comp_ev}, shape_strides_owner);
324+
332325
host_tasks.push_back(tmp_cleanup_ev);
333326

334327
return std::make_pair(

0 commit comments

Comments
 (0)