Skip to content

Commit 5408229

Browse files
committed
ensure program/linker have consistent backend and handle
1 parent 02a4178 commit 5408229

File tree

4 files changed

+29
-22
lines changed

4 files changed

+29
-22
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,14 @@ def _input_type_from_code_type(self, code_type: str):
484484

485485
@property
486486
def handle(self):
487-
"""Return the linker handle object."""
487+
"""Return the underlying handle object."""
488488
return self._mnff.handle
489489

490+
@property
491+
def backend(self) -> str:
492+
"""Return this Linker instance's underlying backend."""
493+
return "nvJitLink" if self._mnff.use_nvjitlink else "driver"
494+
490495
def close(self):
491496
"""Destroy this linker."""
492497
self._mnff.close()

cuda_core/cuda/core/experimental/_program.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
382382
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
383383

384384
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
385-
self._backend = "nvrtc"
385+
self._backend = "NVRTC"
386386
self._linker = None
387387

388388
elif code_type == "ptx":
@@ -391,7 +391,7 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
391391
self._linker = Linker(
392392
ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
393393
)
394-
self._backend = "linker"
394+
self._backend = self._linker.backend
395395
else:
396396
raise NotImplementedError
397397

@@ -445,9 +445,9 @@ def compile(self, target_type, name_expressions=(), logs=None):
445445
446446
"""
447447
if target_type not in self._supported_target_type:
448-
raise NotImplementedError
448+
raise ValueError(f"the target type {target_type} is not supported")
449449

450-
if self._backend == "nvrtc":
450+
if self._backend == "NVRTC":
451451
if target_type == "ptx" and not self._can_load_generated_ptx():
452452
warn(
453453
"The CUDA driver version is older than the backend version. "
@@ -489,15 +489,15 @@ def compile(self, target_type, name_expressions=(), logs=None):
489489

490490
return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping)
491491

492-
if self._backend == "linker":
492+
if self._backend in ("nvJitLink", "driver"):
493493
return self._linker.link(target_type)
494494

495495
@property
496-
def backend(self):
497-
"""Return the backend type string associated with this program."""
496+
def backend(self) -> str:
497+
"""Return this Program instance's underlying backend."""
498498
return self._backend
499499

500500
@property
501501
def handle(self):
502-
"""Return the program handle object."""
502+
"""Return the underlying handle object."""
503503
return self._mnff.handle

cuda_core/tests/test_linker.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
device_function_b = "__device__ int B() { return 0; }"
1818
device_function_c = "__device__ int C(int a, int b) { return a + b; }"
1919

20-
culink_backend = _linker._decide_nvjitlink_or_driver()
21-
if not culink_backend:
20+
is_culink_backend = _linker._decide_nvjitlink_or_driver()
21+
if not is_culink_backend:
2222
from cuda.bindings import nvjitlink
2323

2424

@@ -54,7 +54,7 @@ def compile_ltoir_functions(init_cuda):
5454
LinkerOptions(arch=ARCH, debug=True),
5555
LinkerOptions(arch=ARCH, lineinfo=True),
5656
]
57-
if not culink_backend:
57+
if not is_culink_backend:
5858
options += [
5959
LinkerOptions(arch=ARCH, time=True),
6060
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
@@ -85,24 +85,25 @@ def test_linker_init(compile_ptx_functions, options):
8585
linker = Linker(*compile_ptx_functions, options=options)
8686
object_code = linker.link("cubin")
8787
assert isinstance(object_code, ObjectCode)
88+
assert linker.backend == ("driver" if is_culink_backend else "nvJitLink")
8889

8990

9091
def test_linker_init_invalid_arch(compile_ptx_functions):
91-
err = AttributeError if culink_backend else nvjitlink.nvJitLinkError
92+
err = AttributeError if is_culink_backend else nvjitlink.nvJitLinkError
9293
with pytest.raises(err):
9394
options = LinkerOptions(arch="99", ptx=True)
9495
Linker(*compile_ptx_functions, options=options)
9596

9697

97-
@pytest.mark.skipif(culink_backend, reason="culink does not support ptx option")
98+
@pytest.mark.skipif(is_culink_backend, reason="culink does not support ptx option")
9899
def test_linker_link_ptx_nvjitlink(compile_ltoir_functions):
99100
options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
100101
linker = Linker(*compile_ltoir_functions, options=options)
101102
linked_code = linker.link("ptx")
102103
assert isinstance(linked_code, ObjectCode)
103104

104105

105-
@pytest.mark.skipif(not culink_backend, reason="nvjitlink requires lto for ptx linking")
106+
@pytest.mark.skipif(not is_culink_backend, reason="nvjitlink requires lto for ptx linking")
106107
def test_linker_link_ptx_culink(compile_ptx_functions):
107108
options = LinkerOptions(arch=ARCH)
108109
linker = Linker(*compile_ptx_functions, options=options)

cuda_core/tests/test_program.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from cuda.core.experimental._module import Kernel, ObjectCode
1515
from cuda.core.experimental._program import Program, ProgramOptions
1616

17+
is_culink_backend = _linker._decide_nvjitlink_or_driver()
18+
1719

1820
@pytest.fixture(scope="module")
1921
def ptx_code_object():
@@ -50,7 +52,7 @@ def ptx_code_object():
5052
def test_cpp_program_with_various_options(init_cuda, options):
5153
code = 'extern "C" __global__ void my_kernel() {}'
5254
program = Program(code, "c++", options)
53-
assert program.backend == "nvrtc"
55+
assert program.backend == "NVRTC"
5456
program.compile("ptx")
5557
program.close()
5658
assert program.handle is None
@@ -65,8 +67,7 @@ def test_cpp_program_with_various_options(init_cuda, options):
6567
ProgramOptions(prec_sqrt=True),
6668
ProgramOptions(fma=True),
6769
]
68-
if not _linker._decide_nvjitlink_or_driver():
69-
print("Using nvjitlink as the backend because decide() returned false")
70+
if not is_culink_backend:
7071
options += [
7172
ProgramOptions(time=True),
7273
ProgramOptions(split_compile=True),
@@ -76,7 +77,7 @@ def test_cpp_program_with_various_options(init_cuda, options):
7677
@pytest.mark.parametrize("options", options)
7778
def test_ptx_program_with_various_options(init_cuda, ptx_code_object, options):
7879
program = Program(ptx_code_object._module.decode(), "ptx", options=options)
79-
assert program.backend == "linker"
80+
assert program.backend == ("driver" if is_culink_backend else "nvJitLink")
8081
program.compile("cubin")
8182
program.close()
8283
assert program.handle is None
@@ -85,7 +86,7 @@ def test_ptx_program_with_various_options(init_cuda, ptx_code_object, options):
8586
def test_program_init_valid_code_type():
8687
code = 'extern "C" __global__ void my_kernel() {}'
8788
program = Program(code, "c++")
88-
assert program.backend == "nvrtc"
89+
assert program.backend == "NVRTC"
8990
assert program.handle is not None
9091

9192

@@ -125,14 +126,14 @@ def test_program_compile_valid_target_type(init_cuda):
125126
def test_program_compile_invalid_target_type():
126127
code = 'extern "C" __global__ void my_kernel() {}'
127128
program = Program(code, "c++")
128-
with pytest.raises(NotImplementedError):
129+
with pytest.raises(ValueError):
129130
program.compile("invalid_target")
130131

131132

132133
def test_program_backend_property():
133134
code = 'extern "C" __global__ void my_kernel() {}'
134135
program = Program(code, "c++")
135-
assert program.backend == "nvrtc"
136+
assert program.backend == "NVRTC"
136137

137138

138139
def test_program_handle_property():

0 commit comments

Comments
 (0)