Skip to content

Commit 71762af

Browse files
Implement Kernel.num_arguments, and Kernel.arguments_info (#612)
* Implement Kernel.num_arguments, and Kernel.arguments_info * Factor out common logic between num_arguments and arguments_info properties Also parametrize test to check arguments_info to check with all int, and all short arguments. * Use namedtuple turn to pack offset/size for individual kernel argument * Add feature entry to 0.3.0 release notes * Move ParamInfo to class context Used modern namedtuple instance constructor. * _get_arguments_info mustn't treat all error code as equal As `arg_pos` is being incremented, eventually `CUDA_ERROR_INVALID_VALUE` is expected when `arg_pos` exceeds the number of kernel arguments. But other errors like CUDA_ERROR_INVALID_CONTEXT are possible, and they should be treated as errors. Failure to do that was behind the unintuitive behavior I reported in the PR. * Add skip if testing with compute sanitizer, skip if cuda version < 12 Also add a test that if cuda is not initialized, an error is raised when checking for the num_arguments. * Implement suggestions from @leofang 1. Narrow fixture scope 2. Rename fixture to cuda12_prerequisite_check that provide a boolean rather than a pair of versions * Use ParamInfo instead of NamedTuple is annotation This required moving ParamInfo definition from class scope to module scope, since referencing Kernel.ParamInfo from annotations of methods of the Kernel class results in error that Kernel class does not yet exist.
1 parent d1072a2 commit 71762af

File tree

4 files changed

+151
-5
lines changed

4 files changed

+151
-5
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from collections import namedtuple
56
from typing import Optional, Union
67
from warnings import warn
78

@@ -43,6 +44,7 @@ def _lazy_init():
4344
"data": driver.cuLibraryLoadData,
4445
"kernel": driver.cuLibraryGetKernel,
4546
"attribute": driver.cuKernelGetAttribute,
47+
"paraminfo": driver.cuKernelGetParamInfo,
4648
}
4749
_kernel_ctypes = (driver.CUfunction, driver.CUkernel)
4850
else:
@@ -182,6 +184,9 @@ def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
182184
)
183185

184186

187+
ParamInfo = namedtuple("ParamInfo", ["offset", "size"])
188+
189+
185190
class Kernel:
186191
"""Represent a compiled kernel that had been loaded onto the device.
187192
@@ -215,6 +220,36 @@ def attributes(self) -> KernelAttributes:
215220
self._attributes = KernelAttributes._init(self._handle)
216221
return self._attributes
217222

223+
def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]:
224+
attr_impl = self.attributes
225+
if attr_impl._backend_version != "new":
226+
raise NotImplementedError("New backend is required")
227+
arg_pos = 0
228+
param_info_data = []
229+
while True:
230+
result = attr_impl._loader["paraminfo"](self._handle, arg_pos)
231+
if result[0] != driver.CUresult.CUDA_SUCCESS:
232+
break
233+
if param_info:
234+
p_info = ParamInfo(offset=result[1], size=result[2])
235+
param_info_data.append(p_info)
236+
arg_pos = arg_pos + 1
237+
if result[0] != driver.CUresult.CUDA_ERROR_INVALID_VALUE:
238+
handle_return(result)
239+
return arg_pos, param_info_data
240+
241+
@property
242+
def num_arguments(self) -> int:
243+
"""int : The number of arguments of this function"""
244+
num_args, _ = self._get_arguments_info()
245+
return num_args
246+
247+
@property
248+
def arguments_info(self) -> list[ParamInfo]:
249+
"""list[ParamInfo]: (offset, size) for each argument of this function"""
250+
_, param_info = self._get_arguments_info(param_info=True)
251+
return param_info
252+
218253
# TODO: implement from_handle()
219254

220255

cuda_core/docs/source/release/0.3.0-notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Breaking Changes
2020
New features
2121
------------
2222

23+
- :class:`Kernel` adds :property:`Kernel.num_arguments` and :property:`Kernel.arguments_info` for introspection of kernel arguments. (#612)
2324

2425
New examples
2526
------------

cuda_core/tests/conftest.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,45 @@ def init_cuda():
2626
device = Device()
2727
device.set_current()
2828
yield
29-
_device_unset_current()
29+
_ = _device_unset_current()
3030

3131

32-
def _device_unset_current():
32+
def _device_unset_current() -> bool:
33+
"""Pop current CUDA context.
34+
35+
Returns True if context was popped, False it the stack was empty.
36+
"""
3337
ctx = handle_return(driver.cuCtxGetCurrent())
3438
if int(ctx) == 0:
3539
# no active context, do nothing
36-
return
40+
return False
3741
handle_return(driver.cuCtxPopCurrent())
3842
if hasattr(_device._tls, "devices"):
3943
del _device._tls.devices
44+
return True
4045

4146

4247
@pytest.fixture(scope="function")
4348
def deinit_cuda():
4449
# TODO: rename this to e.g. deinit_context
4550
yield
46-
_device_unset_current()
51+
_ = _device_unset_current()
52+
53+
54+
@pytest.fixture(scope="function")
55+
def deinit_all_contexts_function():
56+
def pop_all_contexts():
57+
max_iters = 256
58+
for _ in range(max_iters):
59+
if _device_unset_current():
60+
# context was popped, continue until stack is empty
61+
continue
62+
# no active context, we are ready
63+
break
64+
else:
65+
raise RuntimeError(f"Number of iterations popping current CUDA contexts, exceded {max_iters}")
66+
67+
return pop_all_contexts
4768

4869

4970
# samples relying on cffi could fail as the modules cannot be imported

cuda_core/tests/test_module.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# Copyright 2024 NVIDIA Corporation. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import ctypes
45
import warnings
56

67
import pytest
8+
from conftest import skipif_testing_with_compute_sanitizer
79

810
import cuda.core.experimental
911
from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system
12+
from cuda.core.experimental._utils.cuda_utils import CUDAError, driver, get_binding_version, handle_return
1013

11-
SAXPY_KERNEL = """
14+
SAXPY_KERNEL = r"""
1215
template<typename T>
1316
__global__ void saxpy(const T a,
1417
const T* x,
@@ -23,6 +26,15 @@
2326
"""
2427

2528

29+
@pytest.fixture(scope="module")
30+
def cuda12_prerequisite_check():
31+
# binding availability depends on cuda-python version
32+
# and version of underlying CUDA toolkit
33+
_py_major_ver, _ = get_binding_version()
34+
_driver_ver = handle_return(driver.cuDriverGetVersion())
35+
return _py_major_ver >= 12 and _driver_ver >= 12000
36+
37+
2638
def test_kernel_attributes_init_disabled():
2739
with pytest.raises(RuntimeError, match=r"^KernelAttributes cannot be instantiated directly\."):
2840
cuda.core.experimental._module.KernelAttributes() # Ensure back door is locked.
@@ -156,3 +168,80 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
156168
def test_object_code_handle(get_saxpy_object_code):
157169
mod = get_saxpy_object_code
158170
assert mod.handle is not None
171+
172+
173+
@skipif_testing_with_compute_sanitizer
174+
def test_saxpy_arguments(get_saxpy_kernel, cuda12_prerequisite_check):
175+
if not cuda12_prerequisite_check:
176+
pytest.skip("Test requires CUDA 12")
177+
krn, _ = get_saxpy_kernel
178+
179+
assert krn.num_arguments == 5
180+
181+
assert "ParamInfo" in str(type(krn).arguments_info.fget.__annotations__)
182+
arg_info = krn.arguments_info
183+
n_args = len(arg_info)
184+
assert n_args == krn.num_arguments
185+
186+
class ExpectedStruct(ctypes.Structure):
187+
_fields_ = [
188+
("a", ctypes.c_float),
189+
("x", ctypes.POINTER(ctypes.c_float)),
190+
("y", ctypes.POINTER(ctypes.c_float)),
191+
("out", ctypes.POINTER(ctypes.c_float)),
192+
("N", ctypes.c_size_t),
193+
]
194+
195+
offsets = [p.offset for p in arg_info]
196+
sizes = [p.size for p in arg_info]
197+
members = [getattr(ExpectedStruct, name) for name, _ in ExpectedStruct._fields_]
198+
expected_offsets = tuple(m.offset for m in members)
199+
assert all(actual == expected for actual, expected in zip(offsets, expected_offsets))
200+
expected_sizes = tuple(m.size for m in members)
201+
assert all(actual == expected for actual, expected in zip(sizes, expected_sizes))
202+
203+
204+
@skipif_testing_with_compute_sanitizer
205+
@pytest.mark.parametrize("nargs", [0, 1, 2, 3, 16])
206+
@pytest.mark.parametrize("c_type_name,c_type", [("int", ctypes.c_int), ("short", ctypes.c_short)], ids=["int", "short"])
207+
def test_num_arguments(init_cuda, nargs, c_type_name, c_type, cuda12_prerequisite_check):
208+
if not cuda12_prerequisite_check:
209+
pytest.skip("Test requires CUDA 12")
210+
args_str = ", ".join([f"{c_type_name} p_{i}" for i in range(nargs)])
211+
src = f"__global__ void foo{nargs}({args_str}) {{ }}"
212+
prog = Program(src, code_type="c++")
213+
mod = prog.compile(
214+
"cubin",
215+
name_expressions=(f"foo{nargs}",),
216+
)
217+
krn = mod.get_kernel(f"foo{nargs}")
218+
assert krn.num_arguments == nargs
219+
220+
class ExpectedStruct(ctypes.Structure):
221+
_fields_ = [(f"arg_{i}", c_type) for i in range(nargs)]
222+
223+
members = tuple(getattr(ExpectedStruct, f"arg_{i}") for i in range(nargs))
224+
225+
arg_info = krn.arguments_info
226+
assert all([actual.offset == expected.offset for actual, expected in zip(arg_info, members)])
227+
assert all([actual.size == expected.size for actual, expected in zip(arg_info, members)])
228+
229+
230+
@skipif_testing_with_compute_sanitizer
231+
def test_num_args_error_handling(deinit_all_contexts_function, cuda12_prerequisite_check):
232+
if not cuda12_prerequisite_check:
233+
pytest.skip("Test requires CUDA 12")
234+
src = "__global__ void foo(int a) { }"
235+
prog = Program(src, code_type="c++")
236+
mod = prog.compile(
237+
"cubin",
238+
name_expressions=("foo",),
239+
)
240+
krn = mod.get_kernel("foo")
241+
# empty driver's context stack using function from conftest
242+
deinit_all_contexts_function()
243+
# with no current context, cuKernelGetParamInfo would report
244+
# exception which we expect to handle by raising
245+
with pytest.raises(CUDAError):
246+
# assignment resolves linter error "B018: useless expression"
247+
_ = krn.num_arguments

0 commit comments

Comments
 (0)