Skip to content

Commit a7f8c30

Browse files
committed
WIP: enable cuLink APIs from driver
1 parent 5207558 commit a7f8c30

File tree

2 files changed

+209
-76
lines changed

2 files changed

+209
-76
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 193 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,64 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import ctypes
56
import weakref
67
from dataclasses import dataclass
78
from typing import List, Optional
89

9-
from cuda.bindings import nvjitlink
10+
from cuda import cuda
1011
from cuda.core.experimental._module import ObjectCode
11-
from cuda.core.experimental._utils import check_or_create_options
12+
from cuda.core.experimental._utils import check_or_create_options, handle_return
13+
14+
# TODO: revisit this treatment for py313t builds
15+
_driver = None # populated if nvJitLink cannot be used
16+
_driver_input_types = None # populated if nvJitLink cannot be used
17+
_driver_ver = None
18+
_inited = False
19+
_nvjitlink = None # populated if nvJitLink can be used
20+
_nvjitlink_input_types = None # populated if nvJitLink cannot be used
21+
22+
23+
def _lazy_init():
24+
global _inited
25+
if _inited:
26+
return
27+
28+
global _driver, _driver_input_types, _driver_ver, _nvjitlink, _nvjitlink_input_types
29+
_driver_ver = handle_return(cuda.cuDriverGetVersion())
30+
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
31+
try:
32+
from cuda.bindings import nvjitlink
33+
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
34+
except ImportError:
35+
# binding is not available
36+
nvjitlink = None
37+
else:
38+
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
39+
# binding is available, but nvJitLink is not installed
40+
nvjitlink = None
41+
elif _driver_ver > nvjitlink.version():
42+
# TODO: nvJitLink is not new enough, warn?
43+
pass
44+
if nvjitlink:
45+
_nvjitlink = nvjitlink
46+
_nvjitlink_input_types = {
47+
"ptx": _nvjitlink.InputType.PTX,
48+
"cubin": _nvjitlink.InputType.CUBIN,
49+
"fatbin": _nvjitlink.InputType.FATBIN,
50+
"ltoir": _nvjitlink.InputType.LTOIR,
51+
"object": _nvjitlink.InputType.OBJECT,
52+
}
53+
else:
54+
from cuda import cuda as _driver
55+
56+
_driver_input_types = {
57+
"ptx": _driver.CUjitInputType.CU_JIT_INPUT_PTX,
58+
"cubin": _driver.CUjitInputType.CU_JIT_INPUT_CUBIN,
59+
"fatbin": _driver.CUjitInputType.CU_JIT_INPUT_FATBINARY,
60+
"object": _driver.CUjitInputType.CU_JIT_INPUT_OBJECT,
61+
}
62+
_inited = True
1263

1364

1465
@dataclass
@@ -146,7 +197,14 @@ class LinkerOptions:
146197
no_cache: Optional[bool] = None
147198

148199
def __post_init__(self):
200+
_lazy_init()
149201
self.formatted_options = []
202+
if _nvjitlink:
203+
self._init_nvjitlink()
204+
else:
205+
self._init_driver()
206+
207+
def _init_nvjitlink(self):
150208
if self.arch is not None:
151209
self.formatted_options.append(f"-arch={self.arch}")
152210
if self.max_register_count is not None:
@@ -191,6 +249,67 @@ def __post_init__(self):
191249
if self.no_cache is not None:
192250
self.formatted_options.append("-no-cache")
193251

252+
def _init_driver(self):
253+
self.option_keys = []
254+
# allocate 4 KiB each for info/error logs
255+
size = 4194304
256+
self.formatted_options.extend((bytearray(size), size, bytearray(size), size))
257+
self.option_keys.extend(
258+
(
259+
_driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
260+
_driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
261+
_driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
262+
_driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
263+
)
264+
)
265+
266+
if self.arch is not None:
267+
arch = self.arch.split("_")[-1].upper()
268+
self.formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
269+
self.option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
270+
# if self.max_register_count is not None:
271+
# self.formatted_options.append(f"-maxrregcount={self.max_register_count}")
272+
# if self.time is not None:
273+
# self.formatted_options.append("-time")
274+
if self.verbose is not None:
275+
self.formatted_options.append(1) # ctypes.c_int32(1))
276+
self.option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
277+
# if self.link_time_optimization is not None:
278+
# self.formatted_options.append("-lto")
279+
# if self.ptx is not None:
280+
# self.formatted_options.append("-ptx")
281+
# if self.optimization_level is not None:
282+
# self.formatted_options.append(f"-O{self.optimization_level}")
283+
# if self.debug is not None:
284+
# self.formatted_options.append("-g")
285+
# if self.lineinfo is not None:
286+
# self.formatted_options.append("-lineinfo")
287+
# if self.ftz is not None:
288+
# self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}")
289+
# if self.prec_div is not None:
290+
# self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
291+
# if self.prec_sqrt is not None:
292+
# self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
293+
# if self.fma is not None:
294+
# self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
295+
# if self.kernels_used is not None:
296+
# for kernel in self.kernels_used:
297+
# self.formatted_options.append(f"-kernels-used={kernel}")
298+
# if self.variables_used is not None:
299+
# for variable in self.variables_used:
300+
# self.formatted_options.append(f"-variables-used={variable}")
301+
# if self.optimize_unused_variables is not None:
302+
# self.formatted_options.append("-optimize-unused-variables")
303+
# if self.xptxas is not None:
304+
# for opt in self.xptxas:
305+
# self.formatted_options.append(f"-Xptxas={opt}")
306+
# if self.split_compile is not None:
307+
# self.formatted_options.append(f"-split-compile={self.split_compile}")
308+
# if self.split_compile_extended is not None:
309+
# self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
310+
# if self.no_cache is not None:
311+
# self.formatted_options.append("-no-cache")
312+
194313

195314
class Linker:
196315
"""
@@ -202,102 +321,116 @@ class Linker:
202321
One or more ObjectCode objects to be linked.
203322
options : LinkerOptions, optional
204323
Options for the linker. If not provided, default options will be used.
205-
206-
Attributes
207-
----------
208-
_options : LinkerOptions
209-
The options used for the linker.
210-
_handle : handle
211-
The handle to the linker created by nvjitlink.
212-
213-
Methods
214-
-------
215-
_add_code_object(object_code)
216-
Adds an object code to the linker.
217-
close()
218-
Closes the linker and releases resources.
219324
"""
220325

221326
class _MembersNeededForFinalize:
222-
__slots__ = ("handle",)
327+
__slots__ = ("handle", "use_nvjitlink")
223328

224-
def __init__(self, program_obj, handle):
329+
def __init__(self, program_obj, handle, use_nvjitlink):
225330
self.handle = handle
331+
self.use_nvjitlink = use_nvjitlink
226332
weakref.finalize(program_obj, self.close)
227333

228334
def close(self):
229335
if self.handle is not None:
230-
nvjitlink.destroy(self.handle)
336+
if self.use_nvjitlink:
337+
_nvjitlink.destroy(self.handle)
338+
else:
339+
handle_return(_driver.cuLinkDestroy(self.handle))
231340
self.handle = None
232341

233342
__slots__ = ("__weakref__", "_mnff", "_options")
234343

235344
def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
236-
self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
237-
self._mnff = Linker._MembersNeededForFinalize(
238-
self, nvjitlink.create(len(options.formatted_options), options.formatted_options)
239-
)
240-
241345
if len(object_codes) == 0:
242346
raise ValueError("At least one ObjectCode object must be provided")
243347

348+
self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
349+
if _nvjitlink:
350+
handle = _nvjitlink.create(len(options.formatted_options), options.formatted_options)
351+
use_nvjitlink = True
352+
else:
353+
handle = handle_return(
354+
_driver.cuLinkCreate(len(options.formatted_options), options.option_keys, options.formatted_options)
355+
)
356+
use_nvjitlink = False
357+
self._mnff = Linker._MembersNeededForFinalize(self, handle, use_nvjitlink)
358+
244359
for code in object_codes:
245360
assert isinstance(code, ObjectCode)
246361
self._add_code_object(code)
247362

248363
def _add_code_object(self, object_code: ObjectCode):
249364
data = object_code._module
250365
assert isinstance(data, bytes)
251-
nvjitlink.add_data(
252-
self._mnff.handle,
253-
self._input_type_from_code_type(object_code._code_type),
254-
data,
255-
len(data),
256-
f"{object_code._handle}_{object_code._code_type}",
257-
)
258-
259-
_get_linked_methods = {
260-
"cubin": (nvjitlink.get_linked_cubin_size, nvjitlink.get_linked_cubin),
261-
"ptx": (nvjitlink.get_linked_ptx_size, nvjitlink.get_linked_ptx),
262-
}
366+
if _nvjitlink:
367+
_nvjitlink.add_data(
368+
self._mnff.handle,
369+
self._input_type_from_code_type(object_code._code_type),
370+
data,
371+
len(data),
372+
f"{object_code._handle}_{object_code._code_type}",
373+
)
374+
else:
375+
handle_return(
376+
_driver.cuLinkAddData(
377+
self._mnff.handle,
378+
self._input_type_from_code_type(object_code._code_type),
379+
data,
380+
len(data),
381+
f"{object_code._handle}_{object_code._code_type}".encode(),
382+
0,
383+
None,
384+
None,
385+
)
386+
)
263387

264388
def link(self, target_type) -> ObjectCode:
265-
nvjitlink.complete(self._mnff.handle)
266-
get_linked = self._get_linked_methods.get(target_type)
267-
if get_linked is None:
389+
if target_type not in ("cubin", "ptx"):
268390
raise ValueError(f"Unsupported target type: {target_type}")
391+
if _nvjitlink:
392+
_nvjitlink.complete(self._mnff.handle)
393+
if target_type == "cubin":
394+
get_size = _nvjitlink.get_linked_cubin_size
395+
get_code = _nvjitlink.get_linked_cubin
396+
else:
397+
get_size = _nvjitlink.get_linked_ptx_size
398+
get_code = _nvjitlink.get_linked_ptx
269399

270-
get_size, get_code = get_linked
271-
size = get_size(self._mnff.handle)
272-
code = bytearray(size)
273-
get_code(self._mnff.handle, code)
400+
size = get_size(self._mnff.handle)
401+
code = bytearray(size)
402+
get_code(self._mnff.handle, code)
403+
else:
404+
addr, size = handle_return(_driver.cuLinkComplete(self._mnff.handle))
405+
code = (ctypes.c_char * size).from_address(addr)
274406

275407
return ObjectCode(bytes(code), target_type)
276408

277409
def get_error_log(self) -> str:
278-
log_size = nvjitlink.get_error_log_size(self._mnff.handle)
279-
log = bytearray(log_size)
280-
nvjitlink.get_error_log(self._mnff.handle, log)
410+
if _nvjitlink:
411+
log_size = _nvjitlink.get_error_log_size(self._mnff.handle)
412+
log = bytearray(log_size)
413+
_nvjitlink.get_error_log(self._mnff.handle, log)
414+
else:
415+
log = self._options.formatted_options[2]
281416
return log.decode()
282417

283418
def get_info_log(self) -> str:
284-
log_size = nvjitlink.get_info_log_size(self._mnff.handle)
285-
log = bytearray(log_size)
286-
nvjitlink.get_info_log(self._mnff.handle, log)
419+
if _nvjitlink:
420+
log_size = _nvjitlink.get_info_log_size(self._mnff.handle)
421+
log = bytearray(log_size)
422+
_nvjitlink.get_info_log(self._mnff.handle, log)
423+
else:
424+
log = self._options.formatted_options[0]
287425
return log.decode()
288426

289-
_input_types = {
290-
"ptx": nvjitlink.InputType.PTX,
291-
"cubin": nvjitlink.InputType.CUBIN,
292-
"fatbin": nvjitlink.InputType.FATBIN,
293-
"ltoir": nvjitlink.InputType.LTOIR,
294-
"object": nvjitlink.InputType.OBJECT,
295-
}
296-
297-
def _input_type_from_code_type(self, code_type: str) -> nvjitlink.InputType:
427+
def _input_type_from_code_type(self, code_type: str):
298428
# this list is based on the supported values for code_type in the ObjectCode class definition.
299-
# nvjitlink supports other options for input type
300-
input_type = self._input_types.get(code_type)
429+
# nvJitLink/driver support other options for input type
430+
if _nvjitlink:
431+
input_type = _nvjitlink_input_types.get(code_type)
432+
else:
433+
input_type = _driver_input_types.get(code_type)
301434

302435
if input_type is None:
303436
raise ValueError(f"Unknown code_type associated with ObjectCode: {code_type}")

cuda_core/tests/test_linker.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,23 @@ def compile_ltoir_functions(init_cuda):
3131
"options",
3232
[
3333
LinkerOptions(arch=ARCH),
34-
LinkerOptions(arch=ARCH, max_register_count=32),
35-
LinkerOptions(arch=ARCH, time=True),
34+
# LinkerOptions(arch=ARCH, max_register_count=32),
35+
# LinkerOptions(arch=ARCH, time=True),
3636
LinkerOptions(arch=ARCH, verbose=True),
37-
LinkerOptions(arch=ARCH, optimization_level=3),
38-
LinkerOptions(arch=ARCH, debug=True),
39-
LinkerOptions(arch=ARCH, lineinfo=True),
40-
LinkerOptions(arch=ARCH, ftz=True),
41-
LinkerOptions(arch=ARCH, prec_div=True),
42-
LinkerOptions(arch=ARCH, prec_sqrt=True),
43-
LinkerOptions(arch=ARCH, fma=True),
44-
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
45-
LinkerOptions(arch=ARCH, variables_used=["var1"]),
46-
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
47-
LinkerOptions(arch=ARCH, xptxas=["-v"]),
48-
LinkerOptions(arch=ARCH, split_compile=0),
49-
LinkerOptions(arch=ARCH, split_compile_extended=1),
50-
LinkerOptions(arch=ARCH, no_cache=True),
37+
# LinkerOptions(arch=ARCH, optimization_level=3),
38+
# LinkerOptions(arch=ARCH, debug=True),
39+
# LinkerOptions(arch=ARCH, lineinfo=True),
40+
# LinkerOptions(arch=ARCH, ftz=True),
41+
# LinkerOptions(arch=ARCH, prec_div=True),
42+
# LinkerOptions(arch=ARCH, prec_sqrt=True),
43+
# LinkerOptions(arch=ARCH, fma=True),
44+
# LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
45+
# LinkerOptions(arch=ARCH, variables_used=["var1"]),
46+
# LinkerOptions(arch=ARCH, optimize_unused_variables=True),
47+
# LinkerOptions(arch=ARCH, xptxas=["-v"]),
48+
# LinkerOptions(arch=ARCH, split_compile=0),
49+
# LinkerOptions(arch=ARCH, split_compile_extended=1),
50+
# LinkerOptions(arch=ARCH, no_cache=True),
5151
],
5252
)
5353
def test_linker_init(compile_ptx_functions, options):

0 commit comments

Comments
 (0)