Skip to content

Commit c467550

Browse files
committed
address some review comments
1 parent 81086e0 commit c467550

File tree

2 files changed

+44
-57
lines changed

2 files changed

+44
-57
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Linker:
8686

8787
__slots__ = ("_handle")
8888

89-
def __init__(self, options: LinkerOptions, object_codes = None):
89+
def __init__(self, *object_codes : ObjectCode, options: LinkerOptions = None):
9090
self._handle = None
9191
options = check_or_create_options(LinkerOptions, options, "Linker options")
9292
self._handle = nvjitlink.create(len(options.formatted_options), options.formatted_options)

cuda_core/tests/test_linker.py

Lines changed: 43 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -28,101 +28,88 @@ def compile_ltoir_functions(init_cuda):
2828
return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir
2929

3030

31-
def test_linker_init_valid_options():
32-
options = LinkerOptions(arch=ARCH)
33-
linker = Linker(options)
34-
assert linker.handle is not None
31+
@pytest.mark.parametrize("options", [
32+
LinkerOptions(arch=ARCH),
33+
LinkerOptions(arch=ARCH, max_register_count=32),
34+
LinkerOptions(arch=ARCH, time=True),
35+
LinkerOptions(arch=ARCH, verbose=True),
36+
LinkerOptions(arch=ARCH, optimization_level=3),
37+
LinkerOptions(arch=ARCH, debug=True),
38+
LinkerOptions(arch=ARCH, lineinfo=True),
39+
LinkerOptions(arch=ARCH, ftz=True),
40+
LinkerOptions(arch=ARCH, prec_div=True),
41+
LinkerOptions(arch=ARCH, prec_sqrt=True),
42+
LinkerOptions(arch=ARCH, fma=True),
43+
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
44+
LinkerOptions(arch=ARCH, variables_used=["var1"]),
45+
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
46+
LinkerOptions(arch=ARCH, xptxas=["-v"]),
47+
LinkerOptions(arch=ARCH, split_compile=0),
48+
LinkerOptions(arch=ARCH, split_compile_extended=1),
49+
LinkerOptions(arch=ARCH, jump_table_density=100),
50+
LinkerOptions(arch=ARCH, no_cache=True)
51+
])
52+
def test_linker_init(compile_ptx_functions, options):
53+
linker = Linker(options, object_codes=compile_ptx_functions)
54+
object_code = linker.link("cubin")
55+
assert isinstance(object_code, ObjectCode)
56+
3557

3658
def test_linker_init_invalid_arch():
3759
options = LinkerOptions(arch=None)
3860
with pytest.raises(ValueError):
3961
Linker(options)
4062

41-
def test_linker_init(compile_ptx_functions):
42-
combinations = [
43-
LinkerOptions(arch=ARCH),
44-
LinkerOptions(arch=ARCH, max_register_count=32),
45-
LinkerOptions(arch=ARCH, time=True),
46-
LinkerOptions(arch=ARCH, verbose=True),
47-
LinkerOptions(arch=ARCH, optimization_level=3),
48-
LinkerOptions(arch=ARCH, debug=True),
49-
LinkerOptions(arch=ARCH, lineinfo=True),
50-
LinkerOptions(arch=ARCH, ftz=True),
51-
LinkerOptions(arch=ARCH, prec_div=True),
52-
LinkerOptions(arch=ARCH, prec_sqrt=True),
53-
LinkerOptions(arch=ARCH, fma=True),
54-
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
55-
LinkerOptions(arch=ARCH, variables_used=["var1"]),
56-
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
57-
LinkerOptions(arch=ARCH, xptxas=["-v"]),
58-
LinkerOptions(arch=ARCH, split_compile=0),
59-
LinkerOptions(arch=ARCH, split_compile_extended=1),
60-
LinkerOptions(arch=ARCH, jump_table_density=100),
61-
LinkerOptions(arch=ARCH, no_cache=True)
62-
]
63-
64-
# Try the combinations, with and without providing object codes to the constructor
65-
for i, options in enumerate(combinations):
66-
linker = Linker(options, object_codes=compile_ptx_functions)
67-
object_code = linker.link("cubin")
68-
assert isinstance(object_code, ObjectCode)
6963

7064
def test_linker_add_code_object(compile_ptx_functions):
7165
options = LinkerOptions(arch=ARCH)
7266
linker = Linker(options)
73-
functions = compile_ptx_functions
74-
linker.add_code_object(functions[0])
75-
linker.add_code_object(functions[1])
76-
linker.add_code_object(functions[2])
67+
for functions in compile_ptx_functions:
68+
linker.add_code_object(functions)
69+
7770

7871
def test_linker_link_ptx(compile_ltoir_functions):
7972
options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
8073
linker = Linker(options)
81-
functions = compile_ltoir_functions
82-
linker.add_code_object(functions[0])
83-
linker.add_code_object(functions[1])
84-
linker.add_code_object(functions[2])
74+
for functions in compile_ltoir_functions:
75+
linker.add_code_object(functions)
8576
linked_code = linker.link("ptx")
8677
assert isinstance(linked_code, ObjectCode)
8778

79+
8880
def test_linker_link_cubin(compile_ptx_functions):
8981
options = LinkerOptions(arch=ARCH)
9082
linker = Linker(options)
91-
functions = compile_ptx_functions
92-
linker.add_code_object(functions[0])
93-
linker.add_code_object(functions[1])
94-
linker.add_code_object(functions[2])
83+
for functions in compile_ptx_functions:
84+
linker.add_code_object(functions)
9585
linked_code = linker.link("cubin")
9686
assert isinstance(linked_code, ObjectCode)
9787

88+
9889
def test_linker_link_invalid_target_type(compile_ptx_functions):
9990
options = LinkerOptions(arch=ARCH)
10091
linker = Linker(options)
101-
functions = compile_ptx_functions
102-
linker.add_code_object(functions[0])
103-
linker.add_code_object(functions[1])
104-
linker.add_code_object(functions[2])
92+
for functions in compile_ptx_functions:
93+
linker.add_code_object(functions)
10594
with pytest.raises(ValueError):
10695
linker.link("invalid_target")
10796

97+
10898
def test_linker_get_error_log(compile_ptx_functions):
10999
options = LinkerOptions(arch=ARCH)
110100
linker = Linker(options)
111-
functions = compile_ptx_functions
112-
linker.add_code_object(functions[0])
113-
linker.add_code_object(functions[1])
114-
linker.add_code_object(functions[2])
101+
for functions in compile_ptx_functions:
102+
linker.add_code_object(functions)
115103
linker.link("cubin")
116104
log = linker.get_error_log()
117105
assert isinstance(log, str)
118106

107+
119108
def test_linker_get_info_log(compile_ptx_functions):
120109
options = LinkerOptions(arch=ARCH)
121110
linker = Linker(options)
122-
functions = compile_ptx_functions
123-
linker.add_code_object(functions[0])
124-
linker.add_code_object(functions[1])
125-
linker.add_code_object(functions[2])
111+
for functions in compile_ptx_functions:
112+
linker.add_code_object(functions)
126113
linker.link("cubin")
127114
log = linker.get_info_log()
128115
assert isinstance(log, str)

0 commit comments

Comments
 (0)