Skip to content

Commit 8a49923

Browse files
authored
Make test_nvjitlink.py less noisy. (#666)
* Remove unused code (minimal_ptx_kernel) * Make test_nvjitlink.py less noisy.
1 parent d1c8b0f commit 8a49923

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2-
#
1+
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
32
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
43

54
import pytest
@@ -11,15 +10,13 @@
1110
PTX_VERSIONS = ["5.0", "6.4", "7.0", "8.5"]
1211

1312

14-
def ptx_header(version, arch):
15-
return f"""
16-
.version {version}
17-
.target {arch}
13+
PTX_HEADER = """\
14+
.version {VERSION}
15+
.target {ARCH}
1816
.address_size 64
1917
"""
2018

21-
22-
ptx_kernel = """
19+
PTX_KERNEL = """
2320
.visible .entry _Z6kernelPi(
2421
.param .u64 _Z6kernelPi_param_0
2522
)
@@ -36,20 +33,21 @@ def ptx_header(version, arch):
3633
}
3734
"""
3835

39-
minimal_ptx_kernel = """
40-
.func _MinimalKernel()
41-
{
42-
ret;
43-
}
44-
"""
4536

46-
ptx_kernel_bytes = [
47-
(ptx_header(version, arch) + ptx_kernel).encode("utf-8") for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
48-
]
49-
minimal_ptx_kernel_bytes = [
50-
(ptx_header(version, arch) + minimal_ptx_kernel).encode("utf-8")
51-
for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
52-
]
37+
def _build_arch_ptx_parametrized_callable():
38+
av = tuple(zip(ARCHITECTURES, PTX_VERSIONS))
39+
return pytest.mark.parametrize(
40+
("arch", "ptx_bytes"),
41+
[(a, (PTX_HEADER.format(VERSION=v, ARCH=a) + PTX_KERNEL).encode("utf-8")) for a, v in av],
42+
ids=[f"{a}_{v}" for a, v in av],
43+
)
44+
45+
46+
ARCH_PTX_PARAMETRIZED_CALLABLE = _build_arch_ptx_parametrized_callable()
47+
48+
49+
def arch_ptx_parametrized(func):
50+
return ARCH_PTX_PARAMETRIZED_CALLABLE(func)
5351

5452

5553
def check_nvjitlink_usable():
@@ -108,27 +106,27 @@ def test_complete_empty(option):
108106
nvjitlink.destroy(handle)
109107

110108

111-
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
112-
def test_add_data(option, ptx_bytes):
113-
handle = nvjitlink.create(1, [f"-arch={option}"])
109+
@arch_ptx_parametrized
110+
def test_add_data(arch, ptx_bytes):
111+
handle = nvjitlink.create(1, [f"-arch={arch}"])
114112
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
115113
nvjitlink.complete(handle)
116114
nvjitlink.destroy(handle)
117115

118116

119-
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
120-
def test_add_file(option, ptx_bytes, tmp_path):
121-
handle = nvjitlink.create(1, [f"-arch={option}"])
117+
@arch_ptx_parametrized
118+
def test_add_file(arch, ptx_bytes, tmp_path):
119+
handle = nvjitlink.create(1, [f"-arch={arch}"])
122120
file_path = tmp_path / "test_file.cubin"
123121
file_path.write_bytes(ptx_bytes)
124122
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
125123
nvjitlink.complete(handle)
126124
nvjitlink.destroy(handle)
127125

128126

129-
@pytest.mark.parametrize("option", ARCHITECTURES)
130-
def test_get_error_log(option):
131-
handle = nvjitlink.create(1, [f"-arch={option}"])
127+
@pytest.mark.parametrize("arch", ARCHITECTURES)
128+
def test_get_error_log(arch):
129+
handle = nvjitlink.create(1, [f"-arch={arch}"])
132130
nvjitlink.complete(handle)
133131
log_size = nvjitlink.get_error_log_size(handle)
134132
log = bytearray(log_size)
@@ -137,9 +135,9 @@ def test_get_error_log(option):
137135
nvjitlink.destroy(handle)
138136

139137

140-
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
141-
def test_get_info_log(option, ptx_bytes):
142-
handle = nvjitlink.create(1, [f"-arch={option}"])
138+
@arch_ptx_parametrized
139+
def test_get_info_log(arch, ptx_bytes):
140+
handle = nvjitlink.create(1, [f"-arch={arch}"])
143141
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
144142
nvjitlink.complete(handle)
145143
log_size = nvjitlink.get_info_log_size(handle)
@@ -149,9 +147,9 @@ def test_get_info_log(option, ptx_bytes):
149147
nvjitlink.destroy(handle)
150148

151149

152-
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
153-
def test_get_linked_cubin(option, ptx_bytes):
154-
handle = nvjitlink.create(1, [f"-arch={option}"])
150+
@arch_ptx_parametrized
151+
def test_get_linked_cubin(arch, ptx_bytes):
152+
handle = nvjitlink.create(1, [f"-arch={arch}"])
155153
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
156154
nvjitlink.complete(handle)
157155
cubin_size = nvjitlink.get_linked_cubin_size(handle)
@@ -161,9 +159,9 @@ def test_get_linked_cubin(option, ptx_bytes):
161159
nvjitlink.destroy(handle)
162160

163161

164-
@pytest.mark.parametrize("option", ARCHITECTURES)
165-
def test_get_linked_ptx(option, get_dummy_ltoir):
166-
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
162+
@pytest.mark.parametrize("arch", ARCHITECTURES)
163+
def test_get_linked_ptx(arch, get_dummy_ltoir):
164+
handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"])
167165
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
168166
nvjitlink.complete(handle)
169167
ptx_size = nvjitlink.get_linked_ptx_size(handle)

0 commit comments

Comments
 (0)