@@ -193,11 +193,49 @@ def __post_init__(self):
193
193
194
194
195
195
class Linker :
196
- __slots__ = ("__weakref__" , "_handle" , "_options" )
196
+ """
197
+ Linker class for managing the linking of object codes with specified options.
198
+
199
+ Parameters
200
+ ----------
201
+ object_codes : ObjectCode
202
+ One or more ObjectCode objects to be linked.
203
+ options : LinkerOptions, optional
204
+ Options for the linker. If not provided, default options will be used.
205
+
206
+ Attributes
207
+ ----------
208
+ _options : LinkerOptions
209
+ The options used for the linker.
210
+ _handle : handle
211
+ The handle to the linker created by nvjitlink.
212
+
213
+ Methods
214
+ -------
215
+ _add_code_object(object_code)
216
+ Adds an object code to the linker.
217
+ close()
218
+ Closes the linker and releases resources.
219
+ """
220
+
221
+ class _MembersNeededForFinalize :
222
+ __slots__ = ("handle" ,)
223
+
224
+ def __init__ (self , program_obj , handle ):
225
+ self .handle = handle
226
+ weakref .finalize (program_obj , self .close )
227
+
228
+ def close (self ):
229
+ if self .handle is not None :
230
+ nvjitlink .destroy (self .handle )
231
+ self .handle = None
232
+
233
+ __slots__ = ("__weakref__" , "_mnff" , "_options" )
197
234
198
235
def __init__ (self , * object_codes : ObjectCode , options : LinkerOptions = None ):
199
236
self ._options = options = check_or_create_options (LinkerOptions , options , "Linker options" )
200
- self ._handle = nvjitlink .create (len (options .formatted_options ), options .formatted_options )
237
+ self ._mnff .handle = nvjitlink .create (len (options .formatted_options ), options .formatted_options )
238
+ self ._mnff = Linker ._MembersNeededForFinalize (self , None )
201
239
202
240
if len (object_codes ) == 0 :
203
241
raise ValueError ("At least one ObjectCode object must be provided" )
@@ -212,39 +250,39 @@ def _add_code_object(self, object_code: ObjectCode):
212
250
data = object_code ._module
213
251
assert isinstance (data , bytes )
214
252
nvjitlink .add_data (
215
- self ._handle ,
253
+ self ._mnff . handle ,
216
254
self ._input_type_from_code_type (object_code ._code_type ),
217
255
data ,
218
256
len (data ),
219
257
f"{ object_code ._handle } _{ object_code ._code_type } " ,
220
258
)
221
259
222
260
def link (self , target_type ) -> ObjectCode :
223
- nvjitlink .complete (self ._handle )
261
+ nvjitlink .complete (self ._mnff . handle )
224
262
if target_type not in ("cubin" , "ptx" ):
225
263
raise ValueError (f"Unsupported target type: { target_type } " )
226
264
code = None
227
265
if target_type == "cubin" :
228
- cubin_size = nvjitlink .get_linked_cubin_size (self ._handle )
266
+ cubin_size = nvjitlink .get_linked_cubin_size (self ._mnff . handle )
229
267
code = bytearray (cubin_size )
230
- nvjitlink .get_linked_cubin (self ._handle , code )
268
+ nvjitlink .get_linked_cubin (self ._mnff . handle , code )
231
269
else :
232
- ptx_size = nvjitlink .get_linked_ptx_size (self ._handle )
270
+ ptx_size = nvjitlink .get_linked_ptx_size (self ._mnff . handle )
233
271
code = bytearray (ptx_size )
234
- nvjitlink .get_linked_ptx (self ._handle , code )
272
+ nvjitlink .get_linked_ptx (self ._mnff . handle , code )
235
273
236
274
return ObjectCode (bytes (code ), target_type )
237
275
238
276
def get_error_log (self ) -> str :
239
- log_size = nvjitlink .get_error_log_size (self ._handle )
277
+ log_size = nvjitlink .get_error_log_size (self ._mnff . handle )
240
278
log = bytearray (log_size )
241
- nvjitlink .get_error_log (self ._handle , log )
279
+ nvjitlink .get_error_log (self ._mnff . handle , log )
242
280
return log .decode ()
243
281
244
282
def get_info_log (self ) -> str :
245
- log_size = nvjitlink .get_info_log_size (self ._handle )
283
+ log_size = nvjitlink .get_info_log_size (self ._mnff . handle )
246
284
log = bytearray (log_size )
247
- nvjitlink .get_info_log (self ._handle , log )
285
+ nvjitlink .get_info_log (self ._mnff . handle , log )
248
286
return log .decode ()
249
287
250
288
def _input_type_from_code_type (self , code_type : str ) -> nvjitlink .InputType :
@@ -265,9 +303,7 @@ def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType:
265
303
266
304
@property
267
305
def handle (self ) -> int :
268
- return self ._handle
306
+ return self ._mnff . handle
269
307
270
308
def close (self ):
271
- if self ._handle is not None :
272
- nvjitlink .destroy (self ._handle )
273
- self ._handle = None
309
+ self ._mnff .close ()
0 commit comments