Skip to content

Commit 9bb4651

Browse files
committed
organize ObjectCode __init__
1 parent b21669b commit 9bb4651

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def link(self, target_type) -> ObjectCode:
439439
addr, size = handle_return(_driver.cuLinkComplete(self._mnff.handle))
440440
code = (ctypes.c_char * size).from_address(addr)
441441

442-
return ObjectCode(bytes(code), target_type)
442+
return ObjectCode._init(bytes(code), target_type)
443443

444444
def get_error_log(self) -> str:
445445
"""Get the error log generated by the linker.

cuda_core/cuda/core/experimental/_module.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5-
5+
from typing import Union
66
from warnings import warn
77

88
from cuda.core.experimental._utils import driver, get_binding_version, handle_return, precondition
@@ -220,6 +220,12 @@ class ObjectCode:
220220
221221
Loads the module library with specified module code and JIT options.
222222
223+
Note
224+
----
225+
The public constructor assumes that ``module`` is of code type "cubin".
226+
For all other possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program`
227+
accepts them and returns an `ObjectCode` instance with its ``compile`` method.
228+
223229
Note
224230
----
225231
Usage under CUDA 11.x will only load to the current device
@@ -228,32 +234,32 @@ class ObjectCode:
228234
Parameters
229235
----------
230236
module : Union[bytes, str]
231-
Either a bytes object containing the module to load, or
232-
a file path string containing that module for loading.
233-
code_type : Any
234-
String of the compiled type.
235-
Supported options are "ptx", "cubin", "ltoir" and "fatbin".
236-
jit_options : Optional
237-
Mapping of JIT options to use during module loading.
238-
(Default to no options)
239-
symbol_mapping : Optional
240-
Keyword argument dictionary specifying how symbol names
241-
should be mapped before trying to retrieve them.
242-
(Default to no mappings)
243-
237+
Either a bytes object containing the cubin to load, or
238+
a file path string pointing to the cubin to load.
244239
"""
245240

246-
__slots__ = ("_handle", "_backend_version", "_jit_options", "_code_type", "_module", "_loader", "_sym_map")
241+
__slots__ = ("_handle", "_backend_version", "_code_type", "_module", "_loader", "_sym_map")
247242
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")
248243

249-
def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
244+
def __init__(self, module: Union[bytes, str]):
245+
_lazy_init()
246+
247+
# handle is assigned during _lazy_load
248+
self._handle = None
249+
self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
250+
self._loader = _backend[self._backend_version]
251+
self._code_type = "cubin"
252+
self._module = module
253+
self._sym_map = {}
254+
255+
def _init(module, code_type, *, symbol_mapping=None):
256+
self = ObjectCode.__new__(ObjectCode)
250257
if code_type not in self._supported_code_type:
251258
raise ValueError
252259
_lazy_init()
253260

254261
# handle is assigned during _lazy_load
255262
self._handle = None
256-
self._jit_options = jit_options
257263

258264
self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
259265
self._loader = _backend[self._backend_version]
@@ -262,14 +268,19 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
262268
self._module = module
263269
self._sym_map = {} if symbol_mapping is None else symbol_mapping
264270

271+
return self
272+
265273
# TODO: do we want to unload in a finalizer? Probably not..
266274

267275
def _lazy_load_module(self, *args, **kwargs):
268276
if self._handle is not None:
269277
return
270278
module = self._module
271279
if isinstance(module, str):
272-
self._handle = handle_return(self._loader["file"](module))
280+
if self._backend_version == "new":
281+
self._handle = handle_return(self._loader["file"](module, [], [], 0, [], [], 0))
282+
else: # "old" backend
283+
self._handle = handle_return(self._loader["file"](module))
273284
else:
274285
assert isinstance(module, bytes)
275286
if self._backend_version == "new":

cuda_core/cuda/core/experimental/_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
386386
if not isinstance(code, str):
387387
raise TypeError("ptx Program expects code argument to be a string")
388388
self._linker = Linker(
389-
ObjectCode(code.encode(), code_type), options=self._translate_program_options(options)
389+
ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
390390
)
391391
self._backend = "linker"
392392
else:
@@ -472,7 +472,7 @@ def compile(self, target_type, name_expressions=(), logs=None):
472472
handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
473473
logs.write(log.decode())
474474

475-
return ObjectCode(data, target_type, symbol_mapping=symbol_mapping)
475+
return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping)
476476

477477
if self._backend == "linker":
478478
return self._linker.link(target_type)

0 commit comments

Comments
 (0)