Skip to content

Commit 9f6496f

Browse files
committed
implement from_cubin; add docs; ensure get_kernel cannot be called with lto-ir
1 parent 1d45767 commit 9f6496f

File tree

4 files changed

+36
-34
lines changed

4 files changed

+36
-34
lines changed

cuda_core/cuda/core/experimental/_event.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def close(self):
6666

6767
def __init__(self):
6868
raise NotImplementedError(
69-
"directly creating an Event object can be ambiguous. Please call call Stream.record()."
69+
"directly creating an Event object can be ambiguous. Please call Stream.record()."
7070
)
7171

7272
@staticmethod

cuda_core/cuda/core/experimental/_module.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -213,53 +213,38 @@ def attributes(self):
213213

214214

215215
class ObjectCode:
216-
"""Represent a compiled program that was loaded onto the device.
216+
"""Represent a compiled program to be loaded onto the device.
217217
218218
This object provides a unified interface for different types of
219-
compiled programs that are loaded onto the device.
220-
221-
Loads the module library with specified module code and JIT options.
219+
compiled programs that will be loaded onto the device.
222220
223221
Note
224222
----
225-
The public constructor assumes that ``module`` is of code type "cubin".
226-
For all other possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program`
227-
accepts them and returns an `ObjectCode` instance with its ``compile`` method.
223+
This class has no default constructor. If you already have a cubin that you would
224+
like to load, use the :meth:`from_cubin` alternative constructor. For all other
225+
possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program`
226+
accepts them and returns an :class:`ObjectCode` instance with its
227+
:meth:`~cuda.core.experimental.Program.compile` method.
228228
229229
Note
230230
----
231231
Usage under CUDA 11.x will only load to the current device
232232
context.
233-
234-
Parameters
235-
----------
236-
module : Union[bytes, str]
237-
Either a bytes object containing the cubin to load, or
238-
a file path string pointing to the cubin to load.
239-
symbol_mapping : Optional[dict]
240-
A dictionary specifying how the unmangled symbol names (as keys)
241-
should be mapped to the mangled names before trying to retrieve
242-
them (default to no mappings).
243233
"""
244234

245235
__slots__ = ("_handle", "_backend_version", "_code_type", "_module", "_loader", "_sym_map")
246236
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")
247237

248-
def __init__(self, module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None):
249-
_lazy_init()
250-
251-
# handle is assigned during _lazy_load
252-
self._handle = None
253-
self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
254-
self._loader = _backend[self._backend_version]
255-
self._code_type = "cubin"
256-
self._module = module
257-
self._sym_map = {} if symbol_mapping is None else symbol_mapping
238+
def __init__(self):
239+
raise NotImplementedError(
240+
"directly creating an ObjectCode object can be ambiguous. Please either call Program.compile() "
241+
"or one of the ObjectCode.from_*() constructors"
242+
)
258243

244+
@staticmethod
259245
def _init(module, code_type, *, symbol_mapping: Optional[dict] = None):
260246
self = ObjectCode.__new__(ObjectCode)
261-
if code_type not in self._supported_code_type:
262-
raise ValueError
247+
assert code_type in self._supported_code_type, f"{code_type=} is not supported"
263248
_lazy_init()
264249

265250
# handle is assigned during _lazy_load
@@ -274,6 +259,22 @@ def _init(module, code_type, *, symbol_mapping: Optional[dict] = None):
274259

275260
return self
276261

262+
@staticmethod
263+
def from_cubin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
264+
"""Create an :class:`ObjectCode` instance from an existing cubin.
265+
266+
Parameters
267+
----------
268+
module : Union[bytes, str]
269+
Either a bytes object containing the in-memory cubin to load, or
270+
a file path string pointing to the on-disk cubin to load.
271+
symbol_mapping : Optional[dict]
272+
A dictionary specifying how the unmangled symbol names (as keys)
273+
should be mapped to the mangled names before trying to retrieve
274+
them (default to no mappings).
275+
"""
276+
return ObjectCode._init(module, "cubin", symbol_mapping=symbol_mapping)
277+
277278
# TODO: do we want to unload in a finalizer? Probably not..
278279

279280
def _lazy_load_module(self, *args, **kwargs):
@@ -307,12 +308,12 @@ def get_kernel(self, name):
307308
Newly created kernel object.
308309
309310
"""
311+
if self._code_type not in ("cubin", "ptx", "fatbin"):
312+
raise RuntimeError(f"get_kernel() is not supported for {self._code_type}")
310313
try:
311314
name = self._sym_map[name]
312315
except KeyError:
313316
name = name.encode()
314317

315318
data = handle_return(self._loader["kernel"](self._handle, name))
316319
return Kernel._from_obj(data, self)
317-
318-
# TODO: implement from_handle()

cuda_core/docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ CUDA compilation toolchain
3232

3333
Program
3434
Linker
35+
ObjectCode
3536

3637
:template: dataclass.rst
3738

cuda_core/tests/test_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_object_code_load_cubin(get_saxpy_kernel):
8989
cubin = mod._module
9090
sym_map = mod._sym_map
9191
assert isinstance(cubin, bytes)
92-
mod = ObjectCode(cubin, symbol_mapping=sym_map)
92+
mod = ObjectCode.from_cubin(cubin, symbol_mapping=sym_map)
9393
mod.get_kernel("saxpy<double>") # force loading
9494

9595

@@ -100,5 +100,5 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
100100
assert isinstance(cubin, bytes)
101101
cubin_file = tmp_path / "test.cubin"
102102
cubin_file.write_bytes(cubin)
103-
mod = ObjectCode(str(cubin_file), symbol_mapping=sym_map)
103+
mod = ObjectCode.from_cubin(str(cubin_file), symbol_mapping=sym_map)
104104
mod.get_kernel("saxpy<double>") # force loading

0 commit comments

Comments
 (0)