@@ -184,6 +184,7 @@ def test_saxpy_arguments(get_saxpy_kernel, cuda12_prerequisite_check):
184
184
185
185
assert krn .num_arguments == 5
186
186
187
+ assert "ParamInfo" in str (type (krn ).arguments_info .fget .__annotations__ )
187
188
arg_info = krn .arguments_info
188
189
n_args = len (arg_info )
189
190
assert n_args == krn .num_arguments
@@ -233,7 +234,7 @@ class ExpectedStruct(ctypes.Structure):
233
234
234
235
235
236
@skipif_testing_with_compute_sanitizer
236
- def test_num_args_error_handling (deinit_cuda , cuda12_prerequisite_check ):
237
+ def test_num_args_error_handling (deinit_context_function , cuda12_prerequisite_check ):
237
238
if not cuda12_prerequisite_check :
238
239
pytest .skip ("Test requires CUDA 12" )
239
240
src = "__global__ void foo(int a) { }"
@@ -243,6 +244,10 @@ def test_num_args_error_handling(deinit_cuda, cuda12_prerequisite_check):
243
244
name_expressions = ("foo" ,),
244
245
)
245
246
krn = mod .get_kernel ("foo" )
247
+ # Unset current context using function from conftest
248
+ deinit_context_function ()
249
+ # with no context, cuKernelGetParamInfo would report
250
+ # exception which we expect to handle by raising
246
251
with pytest .raises (CUDAError ):
247
252
# assignment resolves linter error "B018: useless expression"
248
253
_ = krn .num_arguments
0 commit comments