Skip to content

Commit e966189

Browse files
committed
add docstring, copyright header, and switch finalizer pattern
1 parent c8a8dcb commit e966189

File tree

1 file changed

+52
-16
lines changed

1 file changed

+52
-16
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,49 @@ def __post_init__(self):
193193

194194

195195
class Linker:
196-
__slots__ = ("__weakref__", "_handle", "_options")
196+
"""
197+
Linker class for managing the linking of object codes with specified options.
198+
199+
Parameters
200+
----------
201+
object_codes : ObjectCode
202+
One or more ObjectCode objects to be linked.
203+
options : LinkerOptions, optional
204+
Options for the linker. If not provided, default options will be used.
205+
206+
Attributes
207+
----------
208+
_options : LinkerOptions
209+
The options used for the linker.
210+
_handle : handle
211+
The handle to the linker created by nvjitlink.
212+
213+
Methods
214+
-------
215+
_add_code_object(object_code)
216+
Adds an object code to the linker.
217+
close()
218+
Closes the linker and releases resources.
219+
"""
220+
221+
class _MembersNeededForFinalize:
222+
__slots__ = ("handle",)
223+
224+
def __init__(self, program_obj, handle):
225+
self.handle = handle
226+
weakref.finalize(program_obj, self.close)
227+
228+
def close(self):
229+
if self.handle is not None:
230+
nvjitlink.destroy(self.handle)
231+
self.handle = None
232+
233+
__slots__ = ("__weakref__", "_mnff", "_options")
197234

198235
def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
199236
self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
200-
self._handle = nvjitlink.create(len(options.formatted_options), options.formatted_options)
237+
self._mnff.handle = nvjitlink.create(len(options.formatted_options), options.formatted_options)
238+
self._mnff = Linker._MembersNeededForFinalize(self, None)
201239

202240
if len(object_codes) == 0:
203241
raise ValueError("At least one ObjectCode object must be provided")
@@ -212,39 +250,39 @@ def _add_code_object(self, object_code: ObjectCode):
212250
data = object_code._module
213251
assert isinstance(data, bytes)
214252
nvjitlink.add_data(
215-
self._handle,
253+
self._mnff.handle,
216254
self._input_type_from_code_type(object_code._code_type),
217255
data,
218256
len(data),
219257
f"{object_code._handle}_{object_code._code_type}",
220258
)
221259

222260
def link(self, target_type) -> ObjectCode:
223-
nvjitlink.complete(self._handle)
261+
nvjitlink.complete(self._mnff.handle)
224262
if target_type not in ("cubin", "ptx"):
225263
raise ValueError(f"Unsupported target type: {target_type}")
226264
code = None
227265
if target_type == "cubin":
228-
cubin_size = nvjitlink.get_linked_cubin_size(self._handle)
266+
cubin_size = nvjitlink.get_linked_cubin_size(self._mnff.handle)
229267
code = bytearray(cubin_size)
230-
nvjitlink.get_linked_cubin(self._handle, code)
268+
nvjitlink.get_linked_cubin(self._mnff.handle, code)
231269
else:
232-
ptx_size = nvjitlink.get_linked_ptx_size(self._handle)
270+
ptx_size = nvjitlink.get_linked_ptx_size(self._mnff.handle)
233271
code = bytearray(ptx_size)
234-
nvjitlink.get_linked_ptx(self._handle, code)
272+
nvjitlink.get_linked_ptx(self._mnff.handle, code)
235273

236274
return ObjectCode(bytes(code), target_type)
237275

238276
def get_error_log(self) -> str:
239-
log_size = nvjitlink.get_error_log_size(self._handle)
277+
log_size = nvjitlink.get_error_log_size(self._mnff.handle)
240278
log = bytearray(log_size)
241-
nvjitlink.get_error_log(self._handle, log)
279+
nvjitlink.get_error_log(self._mnff.handle, log)
242280
return log.decode()
243281

244282
def get_info_log(self) -> str:
245-
log_size = nvjitlink.get_info_log_size(self._handle)
283+
log_size = nvjitlink.get_info_log_size(self._mnff.handle)
246284
log = bytearray(log_size)
247-
nvjitlink.get_info_log(self._handle, log)
285+
nvjitlink.get_info_log(self._mnff.handle, log)
248286
return log.decode()
249287

250288
def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType:
@@ -265,9 +303,7 @@ def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType:
265303

266304
@property
267305
def handle(self) -> int:
268-
return self._handle
306+
return self._mnff.handle
269307

270308
def close(self):
271-
if self._handle is not None:
272-
nvjitlink.destroy(self._handle)
273-
self._handle = None
309+
self._mnff.close()

0 commit comments

Comments
 (0)