Skip to content

Commit 17c3e10

Browse files
committed
address comments
1 parent e966189 commit 17c3e10

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,9 @@ def close(self):
234234

235235
def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
236236
self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
237-
self._mnff.handle = nvjitlink.create(len(options.formatted_options), options.formatted_options)
238-
self._mnff = Linker._MembersNeededForFinalize(self, None)
237+
self._mnff = Linker._MembersNeededForFinalize(
238+
self, nvjitlink.create(len(options.formatted_options), options.formatted_options)
239+
)
239240

240241
if len(object_codes) == 0:
241242
raise ValueError("At least one ObjectCode object must be provided")
@@ -244,8 +245,6 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
244245
assert isinstance(code, ObjectCode)
245246
self._add_code_object(code)
246247

247-
weakref.finalize(self, self.close)
248-
249248
def _add_code_object(self, object_code: ObjectCode):
250249
data = object_code._module
251250
assert isinstance(data, bytes)
@@ -257,19 +256,21 @@ def _add_code_object(self, object_code: ObjectCode):
257256
f"{object_code._handle}_{object_code._code_type}",
258257
)
259258

259+
_get_linked_methods = {
260+
"cubin": (nvjitlink.get_linked_cubin_size, nvjitlink.get_linked_cubin),
261+
"ptx": (nvjitlink.get_linked_ptx_size, nvjitlink.get_linked_ptx),
262+
}
263+
260264
def link(self, target_type) -> ObjectCode:
261265
nvjitlink.complete(self._mnff.handle)
262-
if target_type not in ("cubin", "ptx"):
266+
get_linked = self._get_linked_methods.get(target_type)
267+
if get_linked is None:
263268
raise ValueError(f"Unsupported target type: {target_type}")
264-
code = None
265-
if target_type == "cubin":
266-
cubin_size = nvjitlink.get_linked_cubin_size(self._mnff.handle)
267-
code = bytearray(cubin_size)
268-
nvjitlink.get_linked_cubin(self._mnff.handle, code)
269-
else:
270-
ptx_size = nvjitlink.get_linked_ptx_size(self._mnff.handle)
271-
code = bytearray(ptx_size)
272-
nvjitlink.get_linked_ptx(self._mnff.handle, code)
269+
270+
get_size, get_code = get_linked
271+
size = get_size(self._mnff.handle)
272+
code = bytearray(size)
273+
get_code(self._mnff.handle, code)
273274

274275
return ObjectCode(bytes(code), target_type)
275276

@@ -285,21 +286,22 @@ def get_info_log(self) -> str:
285286
nvjitlink.get_info_log(self._mnff.handle, log)
286287
return log.decode()
287288

289+
_input_types = {
290+
"ptx": nvjitlink.InputType.PTX,
291+
"cubin": nvjitlink.InputType.CUBIN,
292+
"fatbin": nvjitlink.InputType.FATBIN,
293+
"ltoir": nvjitlink.InputType.LTOIR,
294+
"object": nvjitlink.InputType.OBJECT,
295+
}
296+
288297
def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType:
289298
# this list is based on the supported values for code_type in the ObjectCode class definition.
290299
# nvjitlink supports other options for input type
291-
if code_type == "ptx":
292-
return nvjitlink.InputType.PTX
293-
elif code_type == "cubin":
294-
return nvjitlink.InputType.CUBIN
295-
elif code_type == "fatbin":
296-
return nvjitlink.InputType.FATBIN
297-
elif code_type == "ltoir":
298-
return nvjitlink.InputType.LTOIR
299-
elif code_type == "object":
300-
return nvjitlink.InputType.OBJECT
301-
else:
300+
input_type = self._input_types.get(code_type)
301+
302+
if input_type is None:
302303
raise ValueError(f"Unknown code_type associated with ObjectCode: {code_type}")
304+
return input_type
303305

304306
@property
305307
def handle(self) -> int:

0 commit comments

Comments
 (0)