Skip to content

Commit 912ae11

Browse files
Add fixture deinit_context_function
Use in test_module.py::test_num_args_error_handling Add comments
1 parent 55f6d31 commit 912ae11

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

cuda_core/tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def deinit_cuda():
5151
_device_unset_current()
5252

5353

54+
@pytest.fixture(scope="function")
55+
def deinit_context_function():
56+
return _device_unset_current
57+
58+
5459
# samples relying on cffi could fail as the modules cannot be imported
5560
sys.path.append(os.getcwd())
5661

cuda_core/tests/test_module.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def test_saxpy_arguments(get_saxpy_kernel, cuda12_prerequisite_check):
184184

185185
assert krn.num_arguments == 5
186186

187+
assert "ParamInfo" in str(type(krn).arguments_info.fget.__annotations__)
187188
arg_info = krn.arguments_info
188189
n_args = len(arg_info)
189190
assert n_args == krn.num_arguments
@@ -233,7 +234,7 @@ class ExpectedStruct(ctypes.Structure):
233234

234235

235236
@skipif_testing_with_compute_sanitizer
236-
def test_num_args_error_handling(deinit_cuda, cuda12_prerequisite_check):
237+
def test_num_args_error_handling(deinit_context_function, cuda12_prerequisite_check):
237238
if not cuda12_prerequisite_check:
238239
pytest.skip("Test requires CUDA 12")
239240
src = "__global__ void foo(int a) { }"
@@ -243,6 +244,10 @@ def test_num_args_error_handling(deinit_cuda, cuda12_prerequisite_check):
243244
name_expressions=("foo",),
244245
)
245246
krn = mod.get_kernel("foo")
247+
# Unset current context using function from conftest
248+
deinit_context_function()
249+
# with no context, cuKernelGetParamInfo would report
250+
# exception which we expect to handle by raising
246251
with pytest.raises(CUDAError):
247252
# assignment resolves linter error "B018: useless expression"
248253
_ = krn.num_arguments

0 commit comments

Comments
 (0)