Skip to content

Commit 08b6dd0

Browse files
Fixed memory leak introduced in new methods of usm_ndarray
1 parent c85571b commit 08b6dd0

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ class usm_memory : public py::object
793793
return bool(opaque_ptr);
794794
}
795795

796-
std::shared_ptr<void> get_smart_ptr_owner() const
796+
const std::shared_ptr<void> &get_smart_ptr_owner() const
797797
{
798798
auto const &api = ::dpctl::detail::dpctl_capi::get();
799799
Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
@@ -1114,17 +1114,20 @@ class usm_ndarray : public py::object
11141114
auto const &api = ::dpctl::detail::dpctl_capi::get();
11151115
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
11161116

1117-
if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_))
1117+
if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1118+
Py_DECREF(usm_data);
11181119
return false;
1120+
}
11191121

11201122
Py_MemoryObject *mem_obj =
11211123
reinterpret_cast<Py_MemoryObject *>(usm_data);
11221124
const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
11231125

1126+
Py_DECREF(usm_data);
11241127
return bool(opaque_ptr);
11251128
}
11261129

1127-
std::shared_ptr<void> get_smart_ptr_owner() const
1130+
const std::shared_ptr<void> &get_smart_ptr_owner() const
11281131
{
11291132
PyUSMArrayObject *raw_ar = usm_array_ptr();
11301133

@@ -1133,6 +1136,7 @@ class usm_ndarray : public py::object
11331136
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
11341137

11351138
if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1139+
Py_DECREF(usm_data);
11361140
throw std::runtime_error(
11371141
"usm_ndarray object does not have Memory object "
11381142
"managing lifetime of USM allocation");
@@ -1141,6 +1145,7 @@ class usm_ndarray : public py::object
11411145
Py_MemoryObject *mem_obj =
11421146
reinterpret_cast<Py_MemoryObject *>(usm_data);
11431147
void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1148+
Py_DECREF(usm_data);
11441149

11451150
if (opaque_ptr) {
11461151
auto shptr_ptr =
@@ -1172,28 +1177,32 @@ namespace detail
11721177
struct ManagedMemory
11731178
{
11741179

1175-
static bool is_usm_managed_by_shared_ptr(const py::handle &h)
1180+
static bool is_usm_managed_by_shared_ptr(const py::object &h)
11761181
{
11771182
if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1178-
auto usm_memory_inst = py::cast<dpctl::memory::usm_memory>(h);
1183+
const auto &usm_memory_inst =
1184+
py::cast<dpctl::memory::usm_memory>(h);
11791185
return usm_memory_inst.is_managed_by_smart_ptr();
11801186
}
11811187
else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1182-
auto usm_array_inst = py::cast<dpctl::tensor::usm_ndarray>(h);
1188+
const auto &usm_array_inst =
1189+
py::cast<dpctl::tensor::usm_ndarray>(h);
11831190
return usm_array_inst.is_managed_by_smart_ptr();
11841191
}
11851192

11861193
return false;
11871194
}
11881195

1189-
static std::shared_ptr<void> extract_shared_ptr(const py::handle &h)
1196+
static const std::shared_ptr<void> &extract_shared_ptr(const py::object &h)
11901197
{
11911198
if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1192-
auto usm_memory_inst = py::cast<dpctl::memory::usm_memory>(h);
1199+
const auto &usm_memory_inst =
1200+
py::cast<dpctl::memory::usm_memory>(h);
11931201
return usm_memory_inst.get_smart_ptr_owner();
11941202
}
11951203
else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1196-
auto usm_array_inst = py::cast<dpctl::tensor::usm_ndarray>(h);
1204+
const auto &usm_array_inst =
1205+
py::cast<dpctl::tensor::usm_ndarray>(h);
11971206
return usm_array_inst.get_smart_ptr_owner();
11981207
}
11991208

@@ -1216,10 +1225,11 @@ sycl::event keep_args_alive(sycl::queue &q,
12161225
std::array<std::shared_ptr<void>, num> shp_usm{};
12171226

12181227
for (std::size_t i = 0; i < num; ++i) {
1219-
auto py_obj_i = py_objs[i];
1228+
const auto &py_obj_i = py_objs[i];
12201229
if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1221-
shp_usm[n_usm_owners_held] =
1230+
const auto &shp =
12221231
detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1232+
shp_usm[n_usm_owners_held] = shp;
12231233
++n_usm_owners_held;
12241234
}
12251235
else {

0 commit comments

Comments
 (0)