@@ -793,7 +793,7 @@ class usm_memory : public py::object
793
793
return bool (opaque_ptr );
794
794
}
795
795
796
- std ::shared_ptr < void > get_smart_ptr_owner () const
796
+ const std ::shared_ptr < void > & get_smart_ptr_owner () const
797
797
{
798
798
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
799
799
Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
@@ -1114,17 +1114,20 @@ class usm_ndarray : public py::object
1114
1114
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1115
1115
PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1116
1116
1117
- if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ ))
1117
+ if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ )) {
1118
+ Py_DECREF (usm_data );
1118
1119
return false;
1120
+ }
1119
1121
1120
1122
Py_MemoryObject * mem_obj =
1121
1123
reinterpret_cast < Py_MemoryObject * > (usm_data );
1122
1124
const void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
1123
1125
1126
+ Py_DECREF (usm_data );
1124
1127
return bool (opaque_ptr );
1125
1128
}
1126
1129
1127
- std ::shared_ptr < void > get_smart_ptr_owner () const
1130
+ const std ::shared_ptr < void > & get_smart_ptr_owner () const
1128
1131
{
1129
1132
PyUSMArrayObject * raw_ar = usm_array_ptr ();
1130
1133
@@ -1133,6 +1136,7 @@ class usm_ndarray : public py::object
1133
1136
PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1134
1137
1135
1138
if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ )) {
1139
+ Py_DECREF (usm_data );
1136
1140
throw std ::runtime_error (
1137
1141
"usm_ndarray object does not have Memory object "
1138
1142
"managing lifetime of USM allocation" );
@@ -1141,6 +1145,7 @@ class usm_ndarray : public py::object
1141
1145
Py_MemoryObject * mem_obj =
1142
1146
reinterpret_cast < Py_MemoryObject * > (usm_data );
1143
1147
void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
1148
+ Py_DECREF (usm_data );
1144
1149
1145
1150
if (opaque_ptr ) {
1146
1151
auto shptr_ptr =
@@ -1172,28 +1177,32 @@ namespace detail
1172
1177
struct ManagedMemory
1173
1178
{
1174
1179
1175
- static bool is_usm_managed_by_shared_ptr (const py ::handle & h )
1180
+ static bool is_usm_managed_by_shared_ptr (const py ::object & h )
1176
1181
{
1177
1182
if (py ::isinstance < dpctl ::memory ::usm_memory > (h )) {
1178
- auto usm_memory_inst = py ::cast < dpctl ::memory ::usm_memory > (h );
1183
+ const auto & usm_memory_inst =
1184
+ py ::cast < dpctl ::memory ::usm_memory > (h );
1179
1185
return usm_memory_inst .is_managed_by_smart_ptr ();
1180
1186
}
1181
1187
else if (py ::isinstance < dpctl ::tensor ::usm_ndarray > (h )) {
1182
- auto usm_array_inst = py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1188
+ const auto & usm_array_inst =
1189
+ py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1183
1190
return usm_array_inst .is_managed_by_smart_ptr ();
1184
1191
}
1185
1192
1186
1193
return false;
1187
1194
}
1188
1195
1189
- static std ::shared_ptr < void > extract_shared_ptr (const py ::handle & h )
1196
+ static const std ::shared_ptr < void > & extract_shared_ptr (const py ::object & h )
1190
1197
{
1191
1198
if (py ::isinstance < dpctl ::memory ::usm_memory > (h )) {
1192
- auto usm_memory_inst = py ::cast < dpctl ::memory ::usm_memory > (h );
1199
+ const auto & usm_memory_inst =
1200
+ py ::cast < dpctl ::memory ::usm_memory > (h );
1193
1201
return usm_memory_inst .get_smart_ptr_owner ();
1194
1202
}
1195
1203
else if (py ::isinstance < dpctl ::tensor ::usm_ndarray > (h )) {
1196
- auto usm_array_inst = py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1204
+ const auto & usm_array_inst =
1205
+ py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1197
1206
return usm_array_inst .get_smart_ptr_owner ();
1198
1207
}
1199
1208
@@ -1216,10 +1225,11 @@ sycl::event keep_args_alive(sycl::queue &q,
1216
1225
std ::array < std ::shared_ptr < void > , num > shp_usm {};
1217
1226
1218
1227
for (std ::size_t i = 0 ; i < num ; ++ i ) {
1219
- auto py_obj_i = py_objs [i ];
1228
+ const auto & py_obj_i = py_objs [i ];
1220
1229
if (detail ::ManagedMemory ::is_usm_managed_by_shared_ptr (py_obj_i )) {
1221
- shp_usm [ n_usm_owners_held ] =
1230
+ const auto & shp =
1222
1231
detail ::ManagedMemory ::extract_shared_ptr (py_obj_i );
1232
+ shp_usm [n_usm_owners_held ] = shp ;
1223
1233
++ n_usm_owners_held ;
1224
1234
}
1225
1235
else {
0 commit comments