2
2
#
3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
+ import ctypes
5
6
import weakref
6
7
from dataclasses import dataclass
7
8
from typing import List , Optional
8
9
9
- from cuda . bindings import nvjitlink
10
+ from cuda import cuda
10
11
from cuda .core .experimental ._module import ObjectCode
11
- from cuda .core .experimental ._utils import check_or_create_options
12
+ from cuda .core .experimental ._utils import check_or_create_options , handle_return
13
+
14
+ # TODO: revisit this treatment for py313t builds
15
+ _driver = None # populated if nvJitLink cannot be used
16
+ _driver_input_types = None # populated if nvJitLink cannot be used
17
+ _driver_ver = None
18
+ _inited = False
19
+ _nvjitlink = None # populated if nvJitLink can be used
20
+ _nvjitlink_input_types = None # populated if nvJitLink can be used
21
+
22
+
23
def _lazy_init():
    """Select the linking backend once and cache the choice in module globals.

    The driver version is queried and decoded into a ``(major, minor)`` tuple.
    nvJitLink is selected when its Python binding imports and the
    ``__nvJitLinkVersion`` symbol resolves to a non-zero function pointer;
    otherwise the CUDA driver linker is used. Only the globals belonging to the
    selected backend are populated. Calls after the first return immediately.
    """
    global _inited, _driver, _driver_input_types, _driver_ver, _nvjitlink, _nvjitlink_input_types
    if _inited:
        return

    # The raw value is decoded as 1000 * major + 10 * minor.
    encoded = handle_return(cuda.cuDriverGetVersion())
    _driver_ver = (encoded // 1000, (encoded % 1000) // 10)

    try:
        from cuda.bindings import nvjitlink
        from cuda.bindings._internal import nvjitlink as inner_nvjitlink
    except ImportError:
        # The binding itself is not available.
        nvjitlink = None
    else:
        if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
            # The binding is available, but the nvJitLink library is not installed.
            nvjitlink = None
        elif _driver_ver > nvjitlink.version():
            # TODO: nvJitLink is not new enough, warn?
            pass

    if not nvjitlink:
        # Fall back to the linker exposed by the CUDA driver.
        from cuda import cuda as _driver

        kinds = _driver.CUjitInputType
        _driver_input_types = {
            "ptx": kinds.CU_JIT_INPUT_PTX,
            "cubin": kinds.CU_JIT_INPUT_CUBIN,
            "fatbin": kinds.CU_JIT_INPUT_FATBINARY,
            "object": kinds.CU_JIT_INPUT_OBJECT,
        }
    else:
        _nvjitlink = nvjitlink
        kinds = _nvjitlink.InputType
        _nvjitlink_input_types = {
            "ptx": kinds.PTX,
            "cubin": kinds.CUBIN,
            "fatbin": kinds.FATBIN,
            "ltoir": kinds.LTOIR,
            "object": kinds.OBJECT,
        }
    _inited = True
12
63
13
64
14
65
@dataclass
@@ -146,7 +197,14 @@ class LinkerOptions:
146
197
no_cache : Optional [bool ] = None
147
198
148
199
def __post_init__(self):
    """Build the backend-specific option payload once the fields are set.

    Makes sure a linking backend has been selected, resets
    ``formatted_options`` to an empty list, and then delegates to the
    formatter of whichever backend is active.
    """
    _lazy_init()
    self.formatted_options = []
    populate = self._init_nvjitlink if _nvjitlink else self._init_driver
    populate()
206
+
207
+ def _init_nvjitlink (self ):
150
208
if self .arch is not None :
151
209
self .formatted_options .append (f"-arch={ self .arch } " )
152
210
if self .max_register_count is not None :
@@ -191,6 +249,67 @@ def __post_init__(self):
191
249
if self .no_cache is not None :
192
250
self .formatted_options .append ("-no-cache" )
193
251
252
def _init_driver(self):
    """Translate the options into parallel key/value lists for the driver linker.

    ``option_keys`` receives ``CUjit_option`` keys and ``formatted_options``
    receives the value for the key at the same index. The first four entries
    are always the info-log buffer, its size, the error-log buffer and its
    size. Of the user-facing fields only ``arch`` and ``verbose`` are
    forwarded.
    """
    self.option_keys = []

    # One 4 MiB (4194304-byte) buffer each for the info log and the error log.
    log_bytes = 4194304
    info_log = bytearray(log_bytes)
    error_log = bytearray(log_bytes)
    self.formatted_options.extend((info_log, log_bytes, error_log, log_bytes))

    jit_option = _driver.CUjit_option
    self.option_keys.extend(
        (
            jit_option.CU_JIT_INFO_LOG_BUFFER,
            jit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
            jit_option.CU_JIT_ERROR_LOG_BUFFER,
            jit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
        )
    )

    if self.arch is not None:
        # Keep the part after the last underscore, upper-cased
        # (e.g. "sm_80" -> "80"), and look up the matching CUjit_target member.
        suffix = self.arch.split("_")[-1].upper()
        target = getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{suffix}")
        self.formatted_options.append(target)
        self.option_keys.append(jit_option.CU_JIT_TARGET)

    if self.verbose is not None:
        # Any non-None value switches verbose logging on.
        self.formatted_options.append(1)
        self.option_keys.append(jit_option.CU_JIT_LOG_VERBOSE)

    # TODO: the following fields are not forwarded to the driver linker yet:
    # max_register_count, time, link_time_optimization, ptx,
    # optimization_level, debug, lineinfo, ftz, prec_div, prec_sqrt, fma,
    # kernels_used, variables_used, optimize_unused_variables, xptxas,
    # split_compile, split_compile_extended, no_cache.
312
+
194
313
195
314
class Linker :
196
315
"""
@@ -202,102 +321,116 @@ class Linker:
202
321
One or more ObjectCode objects to be linked.
203
322
options : LinkerOptions, optional
204
323
Options for the linker. If not provided, default options will be used.
205
-
206
- Attributes
207
- ----------
208
- _options : LinkerOptions
209
- The options used for the linker.
210
- _handle : handle
211
- The handle to the linker created by nvjitlink.
212
-
213
- Methods
214
- -------
215
- _add_code_object(object_code)
216
- Adds an object code to the linker.
217
- close()
218
- Closes the linker and releases resources.
219
324
"""
220
325
221
326
class _MembersNeededForFinalize:
    """Own the linker handle so it can be released without the owning object.

    A ``weakref.finalize`` callback registered on the owner calls
    :meth:`close`. ``use_nvjitlink`` records which backend created the handle
    so that the matching destroy call is used.
    """

    __slots__ = ("handle", "use_nvjitlink")

    def __init__(self, program_obj, handle, use_nvjitlink):
        self.handle = handle
        self.use_nvjitlink = use_nvjitlink
        weakref.finalize(program_obj, self.close)

    def close(self):
        """Destroy the handle if it is still held, then forget it."""
        if self.handle is None:
            return
        if self.use_nvjitlink:
            _nvjitlink.destroy(self.handle)
        else:
            handle_return(_driver.cuLinkDestroy(self.handle))
        self.handle = None
232
341
233
342
__slots__ = ("__weakref__" , "_mnff" , "_options" )
234
343
235
344
def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None):
    """Create a link session and add every given ObjectCode to it.

    Parameters
    ----------
    object_codes : ObjectCode
        One or more ObjectCode objects to be linked.
    options : LinkerOptions, optional
        Options for the linker. If not provided, default options will be used.

    Raises
    ------
    ValueError
        If no ObjectCode is given.
    """
    if not object_codes:
        raise ValueError("At least one ObjectCode object must be provided")

    options = check_or_create_options(LinkerOptions, options, "Linker options")
    self._options = options

    option_values = options.formatted_options
    use_nvjitlink = bool(_nvjitlink)
    if use_nvjitlink:
        handle = _nvjitlink.create(len(option_values), option_values)
    else:
        # The driver takes the option keys and their values as parallel lists.
        handle = handle_return(
            _driver.cuLinkCreate(len(option_values), options.option_keys, option_values)
        )
    self._mnff = Linker._MembersNeededForFinalize(self, handle, use_nvjitlink)

    for code in object_codes:
        assert isinstance(code, ObjectCode)
        self._add_code_object(code)
247
362
248
363
def _add_code_object(self, object_code: ObjectCode):
    """Add the bytes held by ``object_code`` to the pending link.

    The input is labelled ``"<handle>_<code_type>"``; the driver backend
    receives that label encoded to bytes.
    """
    data = object_code._module
    assert isinstance(data, bytes)

    handle = self._mnff.handle
    input_type = self._input_type_from_code_type(object_code._code_type)
    name = f"{object_code._handle}_{object_code._code_type}"

    if not _nvjitlink:
        # Trailing arguments: zero per-input options, so no keys and no values.
        handle_return(
            _driver.cuLinkAddData(handle, input_type, data, len(data), name.encode(), 0, None, None)
        )
        return

    _nvjitlink.add_data(handle, input_type, data, len(data), name)
263
387
264
388
def link(self, target_type) -> ObjectCode:
    """Complete the link and return the result as a new ObjectCode.

    Parameters
    ----------
    target_type : str
        ``"cubin"`` or ``"ptx"``. ``"ptx"`` is only available when nvJitLink
        is the active backend.

    Returns
    -------
    ObjectCode
        The linked image, copied into ``bytes`` and tagged with ``target_type``.

    Raises
    ------
    ValueError
        If ``target_type`` is not ``"cubin"`` or ``"ptx"``, or if ``"ptx"`` is
        requested while the driver backend is in use.
    """
    if target_type not in ("cubin", "ptx"):
        raise ValueError(f"Unsupported target type: {target_type}")

    handle = self._mnff.handle
    if _nvjitlink:
        _nvjitlink.complete(handle)
        if target_type == "cubin":
            get_size = _nvjitlink.get_linked_cubin_size
            get_code = _nvjitlink.get_linked_cubin
        else:
            get_size = _nvjitlink.get_linked_ptx_size
            get_code = _nvjitlink.get_linked_ptx

        size = get_size(handle)
        code = bytearray(size)
        get_code(handle, code)
    else:
        if target_type != "cubin":
            # cuLinkComplete only produces a cubin image; returning it tagged
            # as "ptx" would hand the caller mislabelled data.
            raise ValueError(
                f"Unsupported target type: {target_type} (the driver linker can only produce cubin)"
            )
        addr, size = handle_return(_driver.cuLinkComplete(handle))
        # View the image at the returned address; bytes() below copies it out
        # so the result does not depend on the link state staying alive.
        code = (ctypes.c_char * size).from_address(addr)

    return ObjectCode(bytes(code), target_type)
276
408
277
409
def get_error_log(self) -> str:
    """Return the error log of this link session as text.

    With nvJitLink the log is sized and fetched through the library. With the
    driver backend the log is the fixed-size error buffer stored at index 2 of
    the options' ``formatted_options``. That buffer is created zero-filled and
    is normally far larger than the message, so it is cut at the first NUL
    byte rather than decoding the unused padding into the result.

    Returns
    -------
    str
        The decoded error log.
    """
    if _nvjitlink:
        log_size = _nvjitlink.get_error_log_size(self._mnff.handle)
        log = bytearray(log_size)
        _nvjitlink.get_error_log(self._mnff.handle, log)
    else:
        # Keep only the bytes before the first NUL.
        log = self._options.formatted_options[2].partition(b"\x00")[0]
    return log.decode()
282
417
283
418
def get_info_log(self) -> str:
    """Return the info log of this link session as text.

    With nvJitLink the log is sized and fetched through the library. With the
    driver backend the log is the fixed-size info buffer stored at index 0 of
    the options' ``formatted_options``. That buffer is created zero-filled and
    is normally far larger than the message, so it is cut at the first NUL
    byte rather than decoding the unused padding into the result.

    Returns
    -------
    str
        The decoded info log.
    """
    if _nvjitlink:
        log_size = _nvjitlink.get_info_log_size(self._mnff.handle)
        log = bytearray(log_size)
        _nvjitlink.get_info_log(self._mnff.handle, log)
    else:
        # Keep only the bytes before the first NUL.
        log = self._options.formatted_options[0].partition(b"\x00")[0]
    return log.decode()
288
426
289
- _input_types = {
290
- "ptx" : nvjitlink .InputType .PTX ,
291
- "cubin" : nvjitlink .InputType .CUBIN ,
292
- "fatbin" : nvjitlink .InputType .FATBIN ,
293
- "ltoir" : nvjitlink .InputType .LTOIR ,
294
- "object" : nvjitlink .InputType .OBJECT ,
295
- }
296
-
297
- def _input_type_from_code_type (self , code_type : str ) -> nvjitlink .InputType :
427
+ def _input_type_from_code_type (self , code_type : str ):
298
428
# this list is based on the supported values for code_type in the ObjectCode class definition.
299
- # nvjitlink supports other options for input type
300
- input_type = self ._input_types .get (code_type )
429
+ # nvJitLink/driver support other options for input type
430
+ if _nvjitlink :
431
+ input_type = _nvjitlink_input_types .get (code_type )
432
+ else :
433
+ input_type = _driver_input_types .get (code_type )
301
434
302
435
if input_type is None :
303
436
raise ValueError (f"Unknown code_type associated with ObjectCode: { code_type } " )
0 commit comments