Skip to content

Commit 3c588e8

Browse files
Enable serialization/deserialization of ObjectCode instances (#660)
* implement __reduce__ * updates * inline function * fix reduce * apply patch * address reviews --------- Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent 1e5097c commit 3c588e8

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,14 @@ def _init(cls, module, code_type, *, symbol_mapping: Optional[dict] = None):
302302

303303
return self
304304

305+
@classmethod
306+
def _reduce_helper(self, module, code_type, symbol_mapping):
307+
# just for forwarding kwargs
308+
return ObjectCode._init(module, code_type, symbol_mapping=symbol_mapping)
309+
310+
def __reduce__(self):
311+
return ObjectCode._reduce_helper, (self._module, self._code_type, self._sym_map)
312+
305313
@staticmethod
306314
def from_cubin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
307315
"""Create an :class:`ObjectCode` instance from an existing cubin.

cuda_core/tests/test_module.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import ctypes
5+
import pickle # nosec B403, B301
56
import warnings
67

78
import pytest
@@ -245,3 +246,13 @@ def test_num_args_error_handling(deinit_all_contexts_function, cuda12_prerequisi
245246
with pytest.raises(CUDAError):
246247
# assignment resolves linter error "B018: useless expression"
247248
_ = krn.num_arguments
249+
250+
251+
def test_module_serialization_roundtrip(get_saxpy_kernel):
252+
_, objcode = get_saxpy_kernel
253+
result = pickle.loads(pickle.dumps(objcode)) # nosec B403, B301
254+
255+
assert isinstance(result, ObjectCode)
256+
assert objcode.code == result.code
257+
assert objcode._sym_map == result._sym_map
258+
assert objcode._code_type == result._code_type

0 commit comments

Comments
 (0)