@@ -79,7 +79,7 @@ using value_type_of_t = typename value_type_of<T>::type;
79
79
typedef sycl::event (*nan_to_num_fn_ptr_t )(sycl::queue &,
80
80
int ,
81
81
size_t ,
82
- py::ssize_t *,
82
+ const py::ssize_t *,
83
83
const py::object &,
84
84
const py::object &,
85
85
const py::object &,
@@ -93,7 +93,7 @@ template <typename T>
93
93
sycl::event nan_to_num_call (sycl::queue &exec_q,
94
94
int nd,
95
95
size_t nelems,
96
- py::ssize_t *shape_strides,
96
+ const py::ssize_t *shape_strides,
97
97
const py::object &py_nan,
98
98
const py::object &py_posinf,
99
99
const py::object &py_neginf,
@@ -302,15 +302,12 @@ std::pair<sycl::event, sycl::event>
302
302
std::vector<sycl::event> host_tasks{};
303
303
host_tasks.reserve (2 );
304
304
305
- const auto & ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t >(
305
+ auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t >(
306
306
q, host_tasks, simplified_shape, simplified_src_strides,
307
307
simplified_dst_strides);
308
- py:: ssize_t *shape_strides = std::get<0 >(ptr_size_event_triple_);
308
+ auto shape_strides_owner = std::move (std:: get<0 >(ptr_size_event_triple_) );
309
309
const sycl::event ©_shape_ev = std::get<2 >(ptr_size_event_triple_);
310
-
311
- if (shape_strides == nullptr ) {
312
- throw std::runtime_error (" Device memory allocation failed" );
313
- }
310
+ const py::ssize_t *shape_strides = shape_strides_owner.get ();
314
311
315
312
std::vector<sycl::event> all_deps;
316
313
all_deps.reserve (depends.size () + 1 );
@@ -322,13 +319,9 @@ std::pair<sycl::event, sycl::event>
322
319
src_offset, dst_data, dst_offset, all_deps);
323
320
324
321
// async free of shape_strides temporary
325
- auto ctx = q.get_context ();
326
- sycl::event tmp_cleanup_ev = q.submit ([&](sycl::handler &cgh) {
327
- cgh.depends_on (comp_ev);
328
- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
329
- cgh.host_task (
330
- [ctx, shape_strides]() { sycl_free_noexcept (shape_strides, ctx); });
331
- });
322
+ sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free (
323
+ q, {comp_ev}, shape_strides_owner);
324
+
332
325
host_tasks.push_back (tmp_cleanup_ev);
333
326
334
327
return std::make_pair (
0 commit comments