2
2
#
3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
- from contextlib import contextmanager
5
+ from contextlib import contextmanager , nullcontext
6
6
7
7
import pytest
8
8
20
20
device_function_c = "__device__ int C(int a, int b) { return a + b; }"
21
21
22
22
culink_backend = _linker ._decide_nvjitlink_or_driver ()
23
+ skip_options = nullcontext
23
24
if not culink_backend :
24
25
from cuda .bindings import nvjitlink
25
26
27
+ @contextmanager
28
+ def skip_version_specific_linker_options ():
29
+ if culink_backend :
30
+ return
31
+ try :
32
+ yield
33
+ except nvjitlink .nvJitLinkError as e :
34
+ if e .status == nvjitlink .Result .ERROR_UNRECOGNIZED_OPTION :
35
+ pytest .skip ("current nvjitlink version does not support the option provided" )
36
+ else :
37
+ raise
38
+
39
+ skip_options = skip_version_specific_linker_options
40
+
26
41
27
42
@pytest .fixture (scope = "function" )
28
43
def compile_ptx_functions (init_cuda ):
@@ -44,17 +59,6 @@ def compile_ltoir_functions(init_cuda):
44
59
return object_code_a_ltoir , object_code_b_ltoir , object_code_c_ltoir
45
60
46
61
47
- @contextmanager
48
- def skip_version_specific_linker_options ():
49
- if culink_backend :
50
- return
51
- try :
52
- yield
53
- except nvjitlink .nvJitLinkError as e :
54
- if e .status == nvjitlink .Result .ERROR_UNRECOGNIZED_OPTION :
55
- pytest .skip ("current nvjitlink version does not support the option provided" )
56
-
57
-
58
62
culink_options = [
59
63
LinkerOptions (arch = ARCH , verbose = True ),
60
64
LinkerOptions (arch = ARCH , max_register_count = 32 ),
@@ -87,11 +91,11 @@ def skip_version_specific_linker_options():
87
91
],
88
92
)
89
93
def test_linker_init (compile_ptx_functions , options ):
90
- with skip_version_specific_linker_options ():
94
+ with skip_options ():
91
95
linker = Linker (* compile_ptx_functions , options = options )
92
96
93
- object_code = linker .link ("cubin" )
94
- assert isinstance (object_code , ObjectCode )
97
+ object_code = linker .link ("cubin" )
98
+ assert isinstance (object_code , ObjectCode )
95
99
96
100
97
101
def test_linker_init_invalid_arch ():
0 commit comments