1
- # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2
- #
1
+ # Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
3
2
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
3
5
4
import pytest
11
10
PTX_VERSIONS = ["5.0" , "6.4" , "7.0" , "8.5" ]
12
11
13
12
14
- def ptx_header (version , arch ):
15
- return f"""
16
- .version { version }
17
- .target { arch }
13
+ PTX_HEADER = """\
14
+ .version {VERSION}
15
+ .target {ARCH}
18
16
.address_size 64
19
17
"""
20
18
21
-
22
- ptx_kernel = """
19
+ PTX_KERNEL = """
23
20
.visible .entry _Z6kernelPi(
24
21
.param .u64 _Z6kernelPi_param_0
25
22
)
@@ -36,20 +33,21 @@ def ptx_header(version, arch):
36
33
}
37
34
"""
38
35
39
- minimal_ptx_kernel = """
40
- .func _MinimalKernel()
41
- {
42
- ret;
43
- }
44
- """
45
36
46
- ptx_kernel_bytes = [
47
- (ptx_header (version , arch ) + ptx_kernel ).encode ("utf-8" ) for version , arch in zip (PTX_VERSIONS , ARCHITECTURES )
48
- ]
49
- minimal_ptx_kernel_bytes = [
50
- (ptx_header (version , arch ) + minimal_ptx_kernel ).encode ("utf-8" )
51
- for version , arch in zip (PTX_VERSIONS , ARCHITECTURES )
52
- ]
37
+ def _build_arch_ptx_parametrized_callable ():
38
+ av = tuple (zip (ARCHITECTURES , PTX_VERSIONS ))
39
+ return pytest .mark .parametrize (
40
+ ("arch" , "ptx_bytes" ),
41
+ [(a , (PTX_HEADER .format (VERSION = v , ARCH = a ) + PTX_KERNEL ).encode ("utf-8" )) for a , v in av ],
42
+ ids = [f"{ a } _{ v } " for a , v in av ],
43
+ )
44
+
45
+
46
+ ARCH_PTX_PARAMETRIZED_CALLABLE = _build_arch_ptx_parametrized_callable ()
47
+
48
+
49
+ def arch_ptx_parametrized (func ):
50
+ return ARCH_PTX_PARAMETRIZED_CALLABLE (func )
53
51
54
52
55
53
def check_nvjitlink_usable ():
@@ -108,27 +106,27 @@ def test_complete_empty(option):
108
106
nvjitlink .destroy (handle )
109
107
110
108
111
- @pytest . mark . parametrize ( "option, ptx_bytes" , zip ( ARCHITECTURES , ptx_kernel_bytes ))
112
- def test_add_data (option , ptx_bytes ):
113
- handle = nvjitlink .create (1 , [f"-arch={ option } " ])
109
+ @arch_ptx_parametrized
110
+ def test_add_data (arch , ptx_bytes ):
111
+ handle = nvjitlink .create (1 , [f"-arch={ arch } " ])
114
112
nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_bytes , len (ptx_bytes ), "test_data" )
115
113
nvjitlink .complete (handle )
116
114
nvjitlink .destroy (handle )
117
115
118
116
119
- @pytest . mark . parametrize ( "option, ptx_bytes" , zip ( ARCHITECTURES , ptx_kernel_bytes ))
120
- def test_add_file (option , ptx_bytes , tmp_path ):
121
- handle = nvjitlink .create (1 , [f"-arch={ option } " ])
117
+ @arch_ptx_parametrized
118
+ def test_add_file (arch , ptx_bytes , tmp_path ):
119
+ handle = nvjitlink .create (1 , [f"-arch={ arch } " ])
122
120
file_path = tmp_path / "test_file.cubin"
123
121
file_path .write_bytes (ptx_bytes )
124
122
nvjitlink .add_file (handle , nvjitlink .InputType .ANY , str (file_path ))
125
123
nvjitlink .complete (handle )
126
124
nvjitlink .destroy (handle )
127
125
128
126
129
- @pytest .mark .parametrize ("option " , ARCHITECTURES )
130
- def test_get_error_log (option ):
131
- handle = nvjitlink .create (1 , [f"-arch={ option } " ])
127
+ @pytest .mark .parametrize ("arch " , ARCHITECTURES )
128
+ def test_get_error_log (arch ):
129
+ handle = nvjitlink .create (1 , [f"-arch={ arch } " ])
132
130
nvjitlink .complete (handle )
133
131
log_size = nvjitlink .get_error_log_size (handle )
134
132
log = bytearray (log_size )
@@ -137,9 +135,9 @@ def test_get_error_log(option):
137
135
nvjitlink .destroy (handle )
138
136
139
137
140
- @pytest . mark . parametrize ( "option, ptx_bytes" , zip ( ARCHITECTURES , ptx_kernel_bytes ))
141
- def test_get_info_log (option , ptx_bytes ):
142
- handle = nvjitlink .create (1 , [f"-arch={ option } " ])
138
+ @arch_ptx_parametrized
139
+ def test_get_info_log (arch , ptx_bytes ):
140
+ handle = nvjitlink .create (1 , [f"-arch={ arch } " ])
143
141
nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_bytes , len (ptx_bytes ), "test_data" )
144
142
nvjitlink .complete (handle )
145
143
log_size = nvjitlink .get_info_log_size (handle )
@@ -149,9 +147,9 @@ def test_get_info_log(option, ptx_bytes):
149
147
nvjitlink .destroy (handle )
150
148
151
149
152
- @pytest . mark . parametrize ( "option, ptx_bytes" , zip ( ARCHITECTURES , ptx_kernel_bytes ))
153
- def test_get_linked_cubin (option , ptx_bytes ):
154
- handle = nvjitlink .create (1 , [f"-arch={ option } " ])
150
+ @arch_ptx_parametrized
151
+ def test_get_linked_cubin (arch , ptx_bytes ):
152
+ handle = nvjitlink .create (1 , [f"-arch={ arch } " ])
155
153
nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_bytes , len (ptx_bytes ), "test_data" )
156
154
nvjitlink .complete (handle )
157
155
cubin_size = nvjitlink .get_linked_cubin_size (handle )
@@ -161,9 +159,9 @@ def test_get_linked_cubin(option, ptx_bytes):
161
159
nvjitlink .destroy (handle )
162
160
163
161
164
- @pytest .mark .parametrize ("option " , ARCHITECTURES )
165
- def test_get_linked_ptx (option , get_dummy_ltoir ):
166
- handle = nvjitlink .create (3 , [f"-arch={ option } " , "-lto" , "-ptx" ])
162
+ @pytest .mark .parametrize ("arch " , ARCHITECTURES )
163
+ def test_get_linked_ptx (arch , get_dummy_ltoir ):
164
+ handle = nvjitlink .create (3 , [f"-arch={ arch } " , "-lto" , "-ptx" ])
167
165
nvjitlink .add_data (handle , nvjitlink .InputType .LTOIR , get_dummy_ltoir , len (get_dummy_ltoir ), "test_data" )
168
166
nvjitlink .complete (handle )
169
167
ptx_size = nvjitlink .get_linked_ptx_size (handle )
0 commit comments