Skip to content

Commit c3077da

Browse files
authored
Systematically replace __del__ with weakref.finalize() (#246)
* Systematically replace `__del__` with `weakref.finalize()`
* Event._finalize() approach with self._finalizer.Detach()
* Stream._MembersNeededForFinalize() approach. Corresponding demonstration of finalize behavior (immediate cleanup): https://github.com/rwgk/stuff/blob/f6fbd670b8376003c7767c96538d8ab0b1f49d96/random_attic/weakref_finalize_toy_example.py
* Buffer._MembersNeededForFinalize() approach.
* Apply _MembersNeededForFinalize pattern to _event.py
* _module.py: simply keep TODO comment only
* Apply _MembersNeededForFinalize pattern to _program.py
1 parent fd71ced commit c3077da

File tree

7 files changed

+124
-108
lines changed

7 files changed

+124
-108
lines changed

cuda_core/cuda/core/experimental/_event.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import weakref
56
from dataclasses import dataclass
67
from typing import Optional
78

@@ -50,19 +51,29 @@ class Event:
5051
5152
"""
5253

53-
__slots__ = ("_handle", "_timing_disabled", "_busy_waited")
54+
class _MembersNeededForFinalize:
55+
__slots__ = ("handle",)
56+
57+
def __init__(self, event_obj, handle):
58+
self.handle = handle
59+
weakref.finalize(event_obj, self.close)
60+
61+
def close(self):
62+
if self.handle is not None:
63+
handle_return(cuda.cuEventDestroy(self.handle))
64+
self.handle = None
65+
66+
__slots__ = ("__weakref__", "_mnff", "_timing_disabled", "_busy_waited")
5467

5568
def __init__(self):
56-
self._handle = None
5769
raise NotImplementedError(
5870
"directly creating an Event object can be ambiguous. Please call call Stream.record()."
5971
)
6072

6173
@staticmethod
6274
def _init(options: Optional[EventOptions] = None):
6375
self = Event.__new__(Event)
64-
# minimal requirements for the destructor
65-
self._handle = None
76+
self._mnff = Event._MembersNeededForFinalize(self, None)
6677

6778
options = check_or_create_options(EventOptions, options, "Event options")
6879
flags = 0x0
@@ -76,18 +87,12 @@ def _init(options: Optional[EventOptions] = None):
7687
self._busy_waited = True
7788
if options.support_ipc:
7889
raise NotImplementedError("TODO")
79-
self._handle = handle_return(cuda.cuEventCreate(flags))
90+
self._mnff.handle = handle_return(cuda.cuEventCreate(flags))
8091
return self
8192

82-
def __del__(self):
83-
"""Return close(self)"""
84-
self.close()
85-
8693
def close(self):
8794
"""Destroy the event."""
88-
if self._handle:
89-
handle_return(cuda.cuEventDestroy(self._handle))
90-
self._handle = None
95+
self._mnff.close()
9196

9297
@property
9398
def is_timing_disabled(self) -> bool:
@@ -114,12 +119,12 @@ def sync(self):
114119
has been completed.
115120
116121
"""
117-
handle_return(cuda.cuEventSynchronize(self._handle))
122+
handle_return(cuda.cuEventSynchronize(self._mnff.handle))
118123

119124
@property
120125
def is_done(self) -> bool:
121126
"""Return True if all captured works have been completed, otherwise False."""
122-
(result,) = cuda.cuEventQuery(self._handle)
127+
(result,) = cuda.cuEventQuery(self._mnff.handle)
123128
if result == cuda.CUresult.CUDA_SUCCESS:
124129
return True
125130
elif result == cuda.CUresult.CUDA_ERROR_NOT_READY:
@@ -130,4 +135,4 @@ def is_done(self) -> bool:
130135
@property
131136
def handle(self) -> int:
132137
"""Return the underlying cudaEvent_t pointer address as Python int."""
133-
return int(self._handle)
138+
return int(self._mnff.handle)

cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ cdef class ParamHolder:
182182
for i, arg in enumerate(kernel_args):
183183
if isinstance(arg, Buffer):
184184
# we need the address of where the actual buffer address is stored
185-
self.data_addresses[i] = <void*><intptr_t>(arg._ptr.getPtr())
185+
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
186186
continue
187187
elif isinstance(arg, int):
188188
# Here's the dilemma: We want to have a fast path to pass in Python

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def launch(kernel, config, *kernel_args):
131131
drv_cfg = cuda.CUlaunchConfig()
132132
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
133133
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
134-
drv_cfg.hStream = config.stream._handle
134+
drv_cfg.hStream = config.stream.handle
135135
drv_cfg.sharedMemBytes = config.shmem_size
136136
drv_cfg.numAttrs = 0 # TODO
137137
handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))

cuda_core/cuda/core/experimental/_memory.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import weakref
89
from typing import Optional, Tuple, TypeVar
910

1011
from cuda import cuda
@@ -41,17 +42,28 @@ class Buffer:
4142
4243
"""
4344

45+
class _MembersNeededForFinalize:
46+
__slots__ = ("ptr", "size", "mr")
47+
48+
def __init__(self, buffer_obj, ptr, size, mr):
49+
self.ptr = ptr
50+
self.size = size
51+
self.mr = mr
52+
weakref.finalize(buffer_obj, self.close)
53+
54+
def close(self, stream=None):
55+
if self.ptr and self.mr is not None:
56+
if stream is None:
57+
stream = default_stream()
58+
self.mr.deallocate(self.ptr, self.size, stream)
59+
self.ptr = 0
60+
self.mr = None
61+
4462
# TODO: handle ownership? (_mr could be None)
45-
__slots__ = ("_ptr", "_size", "_mr")
63+
__slots__ = ("__weakref__", "_mnff")
4664

4765
def __init__(self, ptr, size, mr: MemoryResource = None):
48-
self._ptr = ptr
49-
self._size = size
50-
self._mr = mr
51-
52-
def __del__(self):
53-
"""Return close(self)."""
54-
self.close()
66+
self._mnff = Buffer._MembersNeededForFinalize(self, ptr, size, mr)
5567

5668
def close(self, stream=None):
5769
"""Deallocate this buffer asynchronously on the given stream.
@@ -67,47 +79,42 @@ def close(self, stream=None):
6779
the default stream.
6880
6981
"""
70-
if self._ptr and self._mr is not None:
71-
if stream is None:
72-
stream = default_stream()
73-
self._mr.deallocate(self._ptr, self._size, stream)
74-
self._ptr = 0
75-
self._mr = None
82+
self._mnff.close(stream)
7683

7784
@property
7885
def handle(self):
7986
"""Return the buffer handle object."""
80-
return self._ptr
87+
return self._mnff.ptr
8188

8289
@property
8390
def size(self):
8491
"""Return the memory size of this buffer."""
85-
return self._size
92+
return self._mnff.size
8693

8794
@property
8895
def memory_resource(self) -> MemoryResource:
8996
"""Return the memory resource associated with this buffer."""
90-
return self._mr
97+
return self._mnff.mr
9198

9299
@property
93100
def is_device_accessible(self) -> bool:
94101
"""Return True if this buffer can be accessed by the GPU, otherwise False."""
95-
if self._mr is not None:
96-
return self._mr.is_device_accessible
102+
if self._mnff.mr is not None:
103+
return self._mnff.mr.is_device_accessible
97104
raise NotImplementedError
98105

99106
@property
100107
def is_host_accessible(self) -> bool:
101108
"""Return True if this buffer can be accessed by the CPU, otherwise False."""
102-
if self._mr is not None:
103-
return self._mr.is_host_accessible
109+
if self._mnff.mr is not None:
110+
return self._mnff.mr.is_host_accessible
104111
raise NotImplementedError
105112

106113
@property
107114
def device_id(self) -> int:
108115
"""Return the device ordinal of this buffer."""
109-
if self._mr is not None:
110-
return self._mr.device_id
116+
if self._mnff.mr is not None:
117+
return self._mnff.mr.device_id
111118
raise NotImplementedError
112119

113120
def copy_to(self, dst: Buffer = None, *, stream) -> Buffer:
@@ -129,12 +136,12 @@ def copy_to(self, dst: Buffer = None, *, stream) -> Buffer:
129136
if stream is None:
130137
raise ValueError("stream must be provided")
131138
if dst is None:
132-
if self._mr is None:
139+
if self._mnff.mr is None:
133140
raise ValueError("a destination buffer must be provided")
134-
dst = self._mr.allocate(self._size, stream)
135-
if dst._size != self._size:
141+
dst = self._mnff.mr.allocate(self._mnff.size, stream)
142+
if dst._mnff.size != self._mnff.size:
136143
raise ValueError("buffer sizes mismatch between src and dst")
137-
handle_return(cuda.cuMemcpyAsync(dst._ptr, self._ptr, self._size, stream._handle))
144+
handle_return(cuda.cuMemcpyAsync(dst._mnff.ptr, self._mnff.ptr, self._mnff.size, stream.handle))
138145
return dst
139146

140147
def copy_from(self, src: Buffer, *, stream):
@@ -151,9 +158,9 @@ def copy_from(self, src: Buffer, *, stream):
151158
"""
152159
if stream is None:
153160
raise ValueError("stream must be provided")
154-
if src._size != self._size:
161+
if src._mnff.size != self._mnff.size:
155162
raise ValueError("buffer sizes mismatch between src and dst")
156-
handle_return(cuda.cuMemcpyAsync(self._ptr, src._ptr, self._size, stream._handle))
163+
handle_return(cuda.cuMemcpyAsync(self._mnff.ptr, src._mnff.ptr, self._mnff.size, stream.handle))
157164

158165
def __dlpack__(
159166
self,
@@ -242,13 +249,13 @@ def __init__(self, dev_id):
242249
def allocate(self, size, stream=None) -> Buffer:
243250
if stream is None:
244251
stream = default_stream()
245-
ptr = handle_return(cuda.cuMemAllocFromPoolAsync(size, self._handle, stream._handle))
252+
ptr = handle_return(cuda.cuMemAllocFromPoolAsync(size, self._handle, stream.handle))
246253
return Buffer(ptr, size, self)
247254

248255
def deallocate(self, ptr, size, stream=None):
249256
if stream is None:
250257
stream = default_stream()
251-
handle_return(cuda.cuMemFreeAsync(ptr, stream._handle))
258+
handle_return(cuda.cuMemFreeAsync(ptr, stream.handle))
252259

253260
@property
254261
def is_device_accessible(self) -> bool:

cuda_core/cuda/core/experimental/_module.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
148148
self._module = module
149149
self._sym_map = {} if symbol_mapping is None else symbol_mapping
150150

151-
def __del__(self):
152-
# TODO: do we want to unload? Probably not..
153-
pass
151+
# TODO: do we want to unload in a finalizer? Probably not..
154152

155153
def get_kernel(self, name):
156154
"""Return the :obj:`Kernel` of a specified name from this object code.

cuda_core/cuda/core/experimental/_program.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import weakref
6+
57
from cuda import nvrtc
68
from cuda.core.experimental._module import ObjectCode
79
from cuda.core.experimental._utils import handle_return
@@ -24,12 +26,25 @@ class Program:
2426
2527
"""
2628

27-
__slots__ = ("_handle", "_backend")
29+
class _MembersNeededForFinalize:
30+
__slots__ = ("handle",)
31+
32+
def __init__(self, program_obj, handle):
33+
self.handle = handle
34+
weakref.finalize(program_obj, self.close)
35+
36+
def close(self):
37+
if self.handle is not None:
38+
handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
39+
self.handle = None
40+
41+
__slots__ = ("__weakref__", "_mnff", "_backend")
2842
_supported_code_type = ("c++",)
2943
_supported_target_type = ("ptx", "cubin", "ltoir")
3044

3145
def __init__(self, code, code_type):
32-
self._handle = None
46+
self._mnff = Program._MembersNeededForFinalize(self, None)
47+
3348
if code_type not in self._supported_code_type:
3449
raise NotImplementedError
3550

@@ -38,20 +53,14 @@ def __init__(self, code, code_type):
3853
raise TypeError
3954
# TODO: support pre-loaded headers & include names
4055
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
41-
self._handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
56+
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
4257
self._backend = "nvrtc"
4358
else:
4459
raise NotImplementedError
4560

46-
def __del__(self):
47-
"""Return close(self)."""
48-
self.close()
49-
5061
def close(self):
5162
"""Destroy this program."""
52-
if self._handle is not None:
53-
handle_return(nvrtc.nvrtcDestroyProgram(self._handle))
54-
self._handle = None
63+
self._mnff.close()
5564

5665
def compile(self, target_type, options=(), name_expressions=(), logs=None):
5766
"""Compile the program with a specific compilation type.
@@ -84,29 +93,29 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None):
8493
if self._backend == "nvrtc":
8594
if name_expressions:
8695
for n in name_expressions:
87-
handle_return(nvrtc.nvrtcAddNameExpression(self._handle, n.encode()), handle=self._handle)
96+
handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle)
8897
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
8998
options = list(o.encode() for o in options)
90-
handle_return(nvrtc.nvrtcCompileProgram(self._handle, len(options), options), handle=self._handle)
99+
handle_return(nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), handle=self._mnff.handle)
91100

92101
size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
93102
comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
94-
size = handle_return(size_func(self._handle), handle=self._handle)
103+
size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle)
95104
data = b" " * size
96-
handle_return(comp_func(self._handle, data), handle=self._handle)
105+
handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle)
97106

98107
symbol_mapping = {}
99108
if name_expressions:
100109
for n in name_expressions:
101110
symbol_mapping[n] = handle_return(
102-
nvrtc.nvrtcGetLoweredName(self._handle, n.encode()), handle=self._handle
111+
nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle
103112
)
104113

105114
if logs is not None:
106-
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle), handle=self._handle)
115+
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle)
107116
if logsize > 1:
108117
log = b" " * logsize
109-
handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log), handle=self._handle)
118+
handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
110119
logs.write(log.decode())
111120

112121
# TODO: handle jit_options for ptx?
@@ -121,4 +130,4 @@ def backend(self):
121130
@property
122131
def handle(self):
123132
"""Return the program handle object."""
124-
return self._handle
133+
return self._mnff.handle

0 commit comments

Comments
 (0)