|
11 | 11 | from cuda.core.experimental._utils import (
|
12 | 12 | _handle_boolean_option,
|
13 | 13 | check_or_create_options,
|
| 14 | + driver, |
14 | 15 | handle_return,
|
15 | 16 | is_nested_sequence,
|
16 | 17 | is_sequence,
|
@@ -413,6 +414,21 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
|
413 | 414 | raise TypeError
|
414 | 415 | # TODO: support pre-loaded headers & include names
|
415 | 416 | # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
|
| 417 | + |
| 418 | + supported_archs = handle_return(nvrtc.nvrtcGetSupportedArchs()) |
| 419 | + |
| 420 | + if options is not None: |
| 421 | + arch_not_supported = options.arch is not None and options.arch not in supported_archs |
| 422 | + default_arch_not_supported = ( |
| 423 | + options.arch is None |
| 424 | + and 10 * Device().compute_capability[0] + Device().compute_capability[1] not in supported_archs |
| 425 | + ) |
| 426 | + |
| 427 | + if arch_not_supported or default_arch_not_supported: |
| 428 | + raise ValueError( |
| 429 | + f"The provided arch, or default arch (that of the current device) " |
| 430 | + f"is not supported by the current backend. Supported architectures: {supported_archs}" |
| 431 | + ) |
416 | 432 | self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
|
417 | 433 | self._backend = "nvrtc"
|
418 | 434 | else:
|
@@ -448,6 +464,12 @@ def compile(self, target_type, name_expressions=(), logs=None):
|
448 | 464 | raise NotImplementedError
|
449 | 465 |
|
450 | 466 | if self._backend == "nvrtc":
|
| 467 | + version = handle_return(nvrtc.nvrtcVersion()) |
| 468 | + if handle_return(driver.cuDriverGetVersion()) > version[0] * 1000 + version[1] * 10: |
| 469 | + raise RuntimeError( |
| 470 | + "The CUDA driver version is newer than the NVRTC version. " |
| 471 | + "Please update your NVRTC library to match the CUDA driver version." |
| 472 | + ) |
451 | 473 | if name_expressions:
|
452 | 474 | for n in name_expressions:
|
453 | 475 | handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle)
|
|
0 commit comments