@@ -234,8 +234,9 @@ def close(self):
234
234
235
235
def __init__ (self , * object_codes : ObjectCode , options : LinkerOptions = None ):
236
236
self ._options = options = check_or_create_options (LinkerOptions , options , "Linker options" )
237
- self ._mnff .handle = nvjitlink .create (len (options .formatted_options ), options .formatted_options )
238
- self ._mnff = Linker ._MembersNeededForFinalize (self , None )
237
+ self ._mnff = Linker ._MembersNeededForFinalize (
238
+ self , nvjitlink .create (len (options .formatted_options ), options .formatted_options )
239
+ )
239
240
240
241
if len (object_codes ) == 0 :
241
242
raise ValueError ("At least one ObjectCode object must be provided" )
@@ -244,8 +245,6 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
244
245
assert isinstance (code , ObjectCode )
245
246
self ._add_code_object (code )
246
247
247
- weakref .finalize (self , self .close )
248
-
249
248
def _add_code_object (self , object_code : ObjectCode ):
250
249
data = object_code ._module
251
250
assert isinstance (data , bytes )
@@ -257,19 +256,21 @@ def _add_code_object(self, object_code: ObjectCode):
257
256
f"{ object_code ._handle } _{ object_code ._code_type } " ,
258
257
)
259
258
259
+ _get_linked_methods = {
260
+ "cubin" : (nvjitlink .get_linked_cubin_size , nvjitlink .get_linked_cubin ),
261
+ "ptx" : (nvjitlink .get_linked_ptx_size , nvjitlink .get_linked_ptx ),
262
+ }
263
+
260
264
def link (self , target_type ) -> ObjectCode :
261
265
nvjitlink .complete (self ._mnff .handle )
262
- if target_type not in ("cubin" , "ptx" ):
266
+ get_linked = self ._get_linked_methods .get (target_type )
267
+ if get_linked is None :
263
268
raise ValueError (f"Unsupported target type: { target_type } " )
264
- code = None
265
- if target_type == "cubin" :
266
- cubin_size = nvjitlink .get_linked_cubin_size (self ._mnff .handle )
267
- code = bytearray (cubin_size )
268
- nvjitlink .get_linked_cubin (self ._mnff .handle , code )
269
- else :
270
- ptx_size = nvjitlink .get_linked_ptx_size (self ._mnff .handle )
271
- code = bytearray (ptx_size )
272
- nvjitlink .get_linked_ptx (self ._mnff .handle , code )
269
+
270
+ get_size , get_code = get_linked
271
+ size = get_size (self ._mnff .handle )
272
+ code = bytearray (size )
273
+ get_code (self ._mnff .handle , code )
273
274
274
275
return ObjectCode (bytes (code ), target_type )
275
276
@@ -285,21 +286,22 @@ def get_info_log(self) -> str:
285
286
nvjitlink .get_info_log (self ._mnff .handle , log )
286
287
return log .decode ()
287
288
289
+ _input_types = {
290
+ "ptx" : nvjitlink .InputType .PTX ,
291
+ "cubin" : nvjitlink .InputType .CUBIN ,
292
+ "fatbin" : nvjitlink .InputType .FATBIN ,
293
+ "ltoir" : nvjitlink .InputType .LTOIR ,
294
+ "object" : nvjitlink .InputType .OBJECT ,
295
+ }
296
+
288
297
def _input_type_from_code_type (self , code_type : str ) -> nvjitlink .InputType :
289
298
# this list is based on the supported values for code_type in the ObjectCode class definition.
290
299
# nvjitlink supports other options for input type
291
- if code_type == "ptx" :
292
- return nvjitlink .InputType .PTX
293
- elif code_type == "cubin" :
294
- return nvjitlink .InputType .CUBIN
295
- elif code_type == "fatbin" :
296
- return nvjitlink .InputType .FATBIN
297
- elif code_type == "ltoir" :
298
- return nvjitlink .InputType .LTOIR
299
- elif code_type == "object" :
300
- return nvjitlink .InputType .OBJECT
301
- else :
300
+ input_type = self ._input_types .get (code_type )
301
+
302
+ if input_type is None :
302
303
raise ValueError (f"Unknown code_type associated with ObjectCode: { code_type } " )
304
+ return input_type
303
305
304
306
@property
305
307
def handle (self ) -> int :
0 commit comments