Skip to content

Commit 7587684

Browse files
committed
propagate py/driver ver check to launch
1 parent b64f337 commit 7587684

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,11 @@ def launch(kernel, config, *kernel_args):
7676
kernel_args = ParamHolder(kernel_args)
7777
args_ptr = kernel_args.ptr
7878

79-
driver_ver = handle_return(cuda.cuDriverGetVersion())
80-
if driver_ver >= 12000:
79+
# Note: CUkernel can still be launched via the old cuLaunchKernel. We check ._backend
80+
# here not because of the CUfunction/CUkernel difference (which depends on whether the
81+
# "old" or "new" module loading APIs are in use), but only as a proxy to check if
82+
# both binding & driver versions support the "Ex" API, which is more feature-rich.
83+
if kernel._backend == "new":
8184
drv_cfg = cuda.CUlaunchConfig()
8285
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
8386
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
@@ -86,7 +89,7 @@ def launch(kernel, config, *kernel_args):
8689
drv_cfg.numAttrs = 0 # TODO
8790
handle_return(cuda.cuLaunchKernelEx(
8891
drv_cfg, int(kernel._handle), args_ptr, 0))
89-
else:
92+
else: # "old" backend
9093
# TODO: check if config has any unsupported attrs
9194
handle_return(cuda.cuLaunchKernel(
9295
int(kernel._handle),

cuda_core/cuda/core/experimental/_module.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,27 @@
3232

3333
class Kernel:
3434

35-
__slots__ = ("_handle", "_module",)
35+
__slots__ = ("_handle", "_module", "_backend")
3636

3737
def __init__(self):
3838
raise NotImplementedError("directly constructing a Kernel instance is not supported")
3939

4040
@staticmethod
41-
def _from_obj(obj, mod):
41+
def _from_obj(obj, mod, backend):
4242
assert isinstance(obj, _kernel_ctypes)
4343
assert isinstance(mod, ObjectCode)
4444
ker = Kernel.__new__(Kernel)
4545
ker._handle = obj
4646
ker._module = mod
47+
ker._backend = backend
4748
return ker
4849

4950
# TODO: implement from_handle()
5051

5152

5253
class ObjectCode:
5354

54-
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
55+
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_loader_backend", "_sym_map")
5556
_supported_code_type = ("cubin", "ptx", "fatbin")
5657

5758
def __init__(self, module, code_type, jit_options=None, *,
@@ -62,6 +63,7 @@ def __init__(self, module, code_type, jit_options=None, *,
6263

6364
backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
6465
self._loader = _backend[backend]
66+
self._loader_backend = backend
6567

6668
if isinstance(module, str):
6769
if driver_ver < 12000 and jit_options is not None:
@@ -94,6 +96,6 @@ def get_kernel(self, name):
9496
except KeyError:
9597
name = name.encode()
9698
data = handle_return(self._loader["kernel"](self._handle, name))
97-
return Kernel._from_obj(data, self)
99+
return Kernel._from_obj(data, self, self._loader_backend)
98100

99101
# TODO: implement from_handle()

0 commit comments

Comments
 (0)