1
1
from cuda .bindings import driver , nvrtc , runtime
2
2
from cuda .core .experimental import _utils
3
3
4
- err , _DRIVER_VERSION = driver .cuDriverGetVersion ()
5
- assert err == driver .CUresult .CUDA_SUCCESS
4
+ _BINDING_VERSION = _utils .get_binding_version ()
6
5
7
6
8
7
def test_driver_error_info ():
@@ -12,7 +11,7 @@ def test_driver_error_info():
12
11
try :
13
12
error = driver .CUresult (code )
14
13
except ValueError :
15
- if _DRIVER_VERSION >= 12000 :
14
+ if _BINDING_VERSION >= ( 12 , 0 ) :
16
15
assert code not in expl_dict
17
16
else :
18
17
assert code in expl_dict
@@ -25,8 +24,9 @@ def test_driver_error_info():
25
24
print (desc )
26
25
print (expl )
27
26
print ()
28
- stray_expl_codes = sorted (set (expl_dict .keys ()) - valid_codes )
29
- assert not stray_expl_codes
27
+ if _BINDING_VERSION >= (12 , 0 ):
28
+ extra_expl_codes = sorted (set (expl_dict .keys ()) - valid_codes )
29
+ assert not extra_expl_codes
30
30
missing_expl_codes = sorted (valid_codes - set (expl_dict .keys ()))
31
31
assert not missing_expl_codes
32
32
@@ -38,7 +38,7 @@ def test_runtime_error_info():
38
38
try :
39
39
error = runtime .cudaError_t (code )
40
40
except ValueError :
41
- if _DRIVER_VERSION >= 12000 :
41
+ if _BINDING_VERSION >= ( 12 , 0 ) :
42
42
assert code not in expl_dict
43
43
else :
44
44
assert code in expl_dict
@@ -51,8 +51,9 @@ def test_runtime_error_info():
51
51
print (desc )
52
52
print (expl )
53
53
print ()
54
- stray_expl_codes = sorted (set (expl_dict .keys ()) - valid_codes )
55
- assert not stray_expl_codes
54
+ if _BINDING_VERSION >= (12 , 0 ):
55
+ extra_expl_codes = sorted (set (expl_dict .keys ()) - valid_codes )
56
+ assert not extra_expl_codes
56
57
missing_expl_codes = sorted (valid_codes - set (expl_dict .keys ()))
57
58
assert not missing_expl_codes
58
59
0 commit comments