@@ -29,6 +29,7 @@ def _lazy_init():
29
29
_driver_ver = handle_return (cuda .cuDriverGetVersion ())
30
30
_driver_ver = (_driver_ver // 1000 , (_driver_ver % 1000 ) // 10 )
31
31
try :
32
+ raise ImportError
32
33
from cuda .bindings import nvjitlink
33
34
from cuda .bindings ._internal import nvjitlink as inner_nvjitlink
34
35
except ImportError :
@@ -267,48 +268,66 @@ def _init_driver(self):
267
268
arch = self .arch .split ("_" )[- 1 ].upper ()
268
269
self .formatted_options .append (getattr (_driver .CUjit_target , f"CU_TARGET_COMPUTE_{ arch } " ))
269
270
self .option_keys .append (_driver .CUjit_option .CU_JIT_TARGET )
270
- # if self.max_register_count is not None:
271
- # self.formatted_options.append(f"-maxrregcount={self.max_register_count}")
272
- # if self.time is not None:
273
- # self.formatted_options.append("-time")
271
+ if self .max_register_count is not None :
272
+ self .formatted_options .append (self .max_register_count )
273
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_MAX_REGISTERS )
274
+ if self .time is not None :
275
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
276
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_WALL_TIME )
274
277
if self .verbose is not None :
275
- self .formatted_options .append (1 ) # ctypes.c_int32(1))
278
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
276
279
self .option_keys .append (_driver .CUjit_option .CU_JIT_LOG_VERBOSE )
277
- # if self.link_time_optimization is not None:
278
- # self.formatted_options.append("-lto")
279
- # if self.ptx is not None:
280
- # self.formatted_options.append("-ptx")
281
- # if self.optimization_level is not None:
282
- # self.formatted_options.append(f"-O{self.optimization_level}")
283
- # if self.debug is not None:
284
- # self.formatted_options.append("-g")
285
- # if self.lineinfo is not None:
286
- # self.formatted_options.append("-lineinfo")
287
- # if self.ftz is not None:
288
- # self.formatted_options.append(f"-ftz={'true' if self.ftz else 'false'}")
289
- # if self.prec_div is not None:
290
- # self.formatted_options.append(f"-prec-div={'true' if self.prec_div else 'false'}")
291
- # if self.prec_sqrt is not None:
292
- # self.formatted_options.append(f"-prec-sqrt={'true' if self.prec_sqrt else 'false'}")
293
- # if self.fma is not None:
294
- # self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
295
- # if self.kernels_used is not None:
296
- # for kernel in self.kernels_used:
297
- # self.formatted_options.append(f"-kernels-used={kernel}")
298
- # if self.variables_used is not None:
299
- # for variable in self.variables_used:
300
- # self.formatted_options.append(f"-variables-used={variable}")
301
- # if self.optimize_unused_variables is not None:
302
- # self.formatted_options.append("-optimize-unused-variables")
303
- # if self.xptxas is not None:
304
- # for opt in self.xptxas:
305
- # self.formatted_options.append(f"-Xptxas={opt}")
306
- # if self.split_compile is not None:
307
- # self.formatted_options.append(f"-split-compile={self.split_compile}")
308
- # if self.split_compile_extended is not None:
309
- # self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
310
- # if self.no_cache is not None:
311
- # self.formatted_options.append("-no-cache")
280
+ if self .link_time_optimization is not None :
281
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
282
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_LTO )
283
+ if self .ptx is not None :
284
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
285
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_GENERATE_LINE_INFO )
286
+ if self .optimization_level is not None :
287
+ self .formatted_options .append (self .optimization_level )
288
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_OPTIMIZATION_LEVEL )
289
+ if self .debug is not None :
290
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
291
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_GENERATE_DEBUG_INFO )
292
+ if self .lineinfo is not None :
293
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
294
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_GENERATE_LINE_INFO )
295
+ if self .ftz is not None :
296
+ self .formatted_options .append (1 if self .ftz else 0 )
297
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_FTZ )
298
+ if self .prec_div is not None :
299
+ self .formatted_options .append (1 if self .prec_div else 0 )
300
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_PREC_DIV )
301
+ if self .prec_sqrt is not None :
302
+ self .formatted_options .append (1 if self .prec_sqrt else 0 )
303
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_PREC_SQRT )
304
+ if self .fma is not None :
305
+ self .formatted_options .append (1 if self .fma else 0 )
306
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_FMA )
307
+ if self .kernels_used is not None :
308
+ for kernel in self .kernels_used :
309
+ self .formatted_options .append (kernel )
310
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_REFERENCED_KERNEL_NAMES )
311
+ if self .variables_used is not None :
312
+ for variable in self .variables_used :
313
+ self .formatted_options .append (variable )
314
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_REFERENCED_VARIABLE_NAMES )
315
+ if self .optimize_unused_variables is not None :
316
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
317
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES )
318
+ if self .xptxas is not None :
319
+ for opt in self .xptxas :
320
+ self .formatted_options .append (opt )
321
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_FAST_COMPILE )
322
+ if self .split_compile is not None :
323
+ self .formatted_options .append (self .split_compile )
324
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_THREADS_PER_BLOCK )
325
+ if self .split_compile_extended is not None :
326
+ self .formatted_options .append (self .split_compile_extended )
327
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_MIN_CTA_PER_SM )
328
+ if self .no_cache is not None :
329
+ self .formatted_options .append (1 ) # ctypes.c_int32(1)
330
+ self .option_keys .append (_driver .CUjit_option .CU_JIT_CACHE_MODE )
312
331
313
332
314
333
class Linker :
@@ -427,10 +446,7 @@ def get_info_log(self) -> str:
427
446
def _input_type_from_code_type (self , code_type : str ):
428
447
# this list is based on the supported values for code_type in the ObjectCode class definition.
429
448
# nvJitLink/driver support other options for input type
430
- if _nvjitlink :
431
- input_type = _nvjitlink_input_types .get (code_type )
432
- else :
433
- input_type = _driver_input_types .get (code_type )
449
+ input_type = _nvjitlink_input_types .get (code_type ) if _nvjitlink else _driver_input_types .get (code_type )
434
450
435
451
if input_type is None :
436
452
raise ValueError (f"Unknown code_type associated with ObjectCode: { code_type } " )
0 commit comments