Skip to content

Commit 5349bd1

Browse files
committed
expose ObjectCode to public + fix file loading
1 parent 9bb4651 commit 5349bd1

File tree

4 files changed

+45
-16
lines changed

4 files changed

+45
-16
lines changed

cuda_core/cuda/core/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from cuda.core.experimental._event import EventOptions
88
from cuda.core.experimental._launcher import LaunchConfig, launch
99
from cuda.core.experimental._linker import Linker, LinkerOptions
10+
from cuda.core.experimental._module import ObjectCode
1011
from cuda.core.experimental._program import Program, ProgramOptions
1112
from cuda.core.experimental._stream import Stream, StreamOptions
1213
from cuda.core.experimental._system import System

cuda_core/cuda/core/experimental/_module.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5-
from typing import Union
5+
from typing import Optional, Union
66
from warnings import warn
77

88
from cuda.core.experimental._utils import driver, get_binding_version, handle_return, precondition
@@ -236,12 +236,16 @@ class ObjectCode:
236236
module : Union[bytes, str]
237237
Either a bytes object containing the cubin to load, or
238238
a file path string pointing to the cubin to load.
239+
symbol_mapping : Optional[dict]
240+
A dictionary specifying how the unmangled symbol names (as keys)
241+
should be mapped to the mangled names before trying to retrieve
242+
them (defaults to no mappings).
239243
"""
240244

241245
__slots__ = ("_handle", "_backend_version", "_code_type", "_module", "_loader", "_sym_map")
242246
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")
243247

244-
def __init__(self, module: Union[bytes, str]):
248+
def __init__(self, module: Union[bytes, str], *, symbol_mapping: Optional[dict]=None):
245249
_lazy_init()
246250

247251
# handle is assigned during _lazy_load
@@ -250,9 +254,9 @@ def __init__(self, module: Union[bytes, str]):
250254
self._loader = _backend[self._backend_version]
251255
self._code_type = "cubin"
252256
self._module = module
253-
self._sym_map = {}
257+
self._sym_map = {} if symbol_mapping is None else symbol_mapping
254258

255-
def _init(module, code_type, *, symbol_mapping=None):
259+
def _init(module, code_type, *, symbol_mapping: Optional[dict]=None):
256260
self = ObjectCode.__new__(ObjectCode)
257261
if code_type not in self._supported_code_type:
258262
raise ValueError
@@ -278,9 +282,9 @@ def _lazy_load_module(self, *args, **kwargs):
278282
module = self._module
279283
if isinstance(module, str):
280284
if self._backend_version == "new":
281-
self._handle = handle_return(self._loader["file"](module, [], [], 0, [], [], 0))
285+
self._handle = handle_return(self._loader["file"](module.encode(), [], [], 0, [], [], 0))
282286
else: # "old" backend
283-
self._handle = handle_return(self._loader["file"](module))
287+
self._handle = handle_return(self._loader["file"](module.encode()))
284288
else:
285289
assert isinstance(module, bytes)
286290
if self._backend_version == "new":

cuda_core/docs/source/release/0.2.0-notes.rst

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,19 @@ Highlights
1212
- Add :class:`~DeviceProperties` to provide pythonic access to device properties.
1313
- Add kernel attributes to :class:`~Kernel`
1414

15-
Limitations
16-
-----------
17-
18-
- <TODO>
19-
2015
Breaking Changes
2116
----------------
2217

2318
- Change ``__cuda_stream__`` from attribute to method
2419
- The :meth:`~Program.compile` method no longer accepts the ``options`` argument. Instead, you can optionally pass an instance of :class:`~ProgramOptions` to the constructor of :obj:`~Program`.
25-
- The internal constructor of :class:`~ObjectCode` no longer accepts the jit_options argument. Options are provided to upstream :class:`~ProgramOptions` or :class:`~LinkerOptions` instead.
26-
- :meth: `~Device.properties` now provides an instance of :class:`~DeviceProperties` instead of a dictionary.
20+
- :meth:`~Device.properties` now provides an instance of :class:`~DeviceProperties` instead of a dictionary.
21+
22+
New features
23+
------------
24+
25+
- Expose :class:`ObjectCode` as a public API, which allows loading cubins from memory or disk. To load other code types, please continue using :class:`Program`.
26+
27+
Limitations
28+
-----------
29+
30+
- <TODO>

cuda_core/tests/test_module.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111
from conftest import can_load_generated_ptx
1212

13-
from cuda.core.experimental import Program, ProgramOptions, system
13+
from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system
1414

1515

1616
@pytest.fixture(scope="function")
@@ -37,7 +37,7 @@ def get_saxpy_kernel(init_cuda):
3737
)
3838

3939
# run in single precision
40-
return mod.get_kernel("saxpy<float>")
40+
return mod.get_kernel("saxpy<float>"), mod
4141

4242

4343
@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
@@ -72,7 +72,7 @@ def test_get_kernel(init_cuda):
7272
],
7373
)
7474
def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
75-
kernel = get_saxpy_kernel
75+
kernel, _ = get_saxpy_kernel
7676
method = getattr(kernel.attributes, attr)
7777
# get the value without providing a device ordinal
7878
value = method()
@@ -82,3 +82,23 @@ def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
8282
for device in system.devices:
8383
value = method(device.device_id)
8484
assert isinstance(value, expected_type), f"Expected {attr} to be of type {expected_type}, but got {type(value)}"
85+
86+
87+
def test_object_code_load_cubin(get_saxpy_kernel):
88+
_, mod = get_saxpy_kernel
89+
cubin = mod._module
90+
sym_map = mod._sym_map
91+
assert isinstance(cubin, bytes)
92+
mod = ObjectCode(cubin, symbol_mapping=sym_map)
93+
ker = mod.get_kernel("saxpy<double>")
94+
95+
96+
def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
97+
_, mod = get_saxpy_kernel
98+
cubin = mod._module
99+
sym_map = mod._sym_map
100+
assert isinstance(cubin, bytes)
101+
cubin_file = tmp_path / "test.cubin"
102+
cubin_file.write_bytes(cubin)
103+
mod = ObjectCode(str(cubin_file), symbol_mapping=sym_map)
104+
ker = mod.get_kernel("saxpy<double>")

0 commit comments

Comments
 (0)