Skip to content

Commit 66b03fc

Browse files
committed
address comments
1 parent 6e5114e commit 66b03fc

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _decide_nvjitlink_or_driver():
3232
_driver_ver = handle_return(cuda.cuDriverGetVersion())
3333
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
3434
try:
35+
raise ImportError
3536
from cuda.bindings import nvjitlink as _nvjitlink
3637
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
3738
except ImportError:

cuda_core/tests/test_linker.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5-
from contextlib import contextmanager
5+
from contextlib import contextmanager, nullcontext
66

77
import pytest
88

@@ -20,9 +20,24 @@
2020
device_function_c = "__device__ int C(int a, int b) { return a + b; }"
2121

2222
culink_backend = _linker._decide_nvjitlink_or_driver()
23+
skip_options = nullcontext
2324
if not culink_backend:
2425
from cuda.bindings import nvjitlink
2526

27+
@contextmanager
28+
def skip_version_specific_linker_options():
29+
if culink_backend:
30+
return
31+
try:
32+
yield
33+
except nvjitlink.nvJitLinkError as e:
34+
if e.status == nvjitlink.Result.ERROR_UNRECOGNIZED_OPTION:
35+
pytest.skip("current nvjitlink version does not support the option provided")
36+
else:
37+
raise
38+
39+
skip_options = skip_version_specific_linker_options
40+
2641

2742
@pytest.fixture(scope="function")
2843
def compile_ptx_functions(init_cuda):
@@ -44,17 +59,6 @@ def compile_ltoir_functions(init_cuda):
4459
return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir
4560

4661

47-
@contextmanager
48-
def skip_version_specific_linker_options():
49-
if culink_backend:
50-
return
51-
try:
52-
yield
53-
except nvjitlink.nvJitLinkError as e:
54-
if e.status == nvjitlink.Result.ERROR_UNRECOGNIZED_OPTION:
55-
pytest.skip("current nvjitlink version does not support the option provided")
56-
57-
5862
culink_options = [
5963
LinkerOptions(arch=ARCH, verbose=True),
6064
LinkerOptions(arch=ARCH, max_register_count=32),
@@ -87,11 +91,11 @@ def skip_version_specific_linker_options():
8791
],
8892
)
8993
def test_linker_init(compile_ptx_functions, options):
90-
with skip_version_specific_linker_options():
94+
with skip_options():
9195
linker = Linker(*compile_ptx_functions, options=options)
9296

93-
object_code = linker.link("cubin")
94-
assert isinstance(object_code, ObjectCode)
97+
object_code = linker.link("cubin")
98+
assert isinstance(object_code, ObjectCode)
9599

96100

97101
def test_linker_init_invalid_arch():

0 commit comments

Comments
 (0)