Skip to content

Commit b2495d3

Browse files
committed
squash
1 parent 023c4bf commit b2495d3

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

cuda_core/cuda/core/experimental/_program.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from cuda.core.experimental._utils import (
1212
_handle_boolean_option,
1313
check_or_create_options,
14+
driver,
1415
handle_return,
1516
is_nested_sequence,
1617
is_sequence,
@@ -413,6 +414,21 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
413414
raise TypeError
414415
# TODO: support pre-loaded headers & include names
415416
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
417+
418+
supported_archs = handle_return(nvrtc.nvrtcGetSupportedArchs())
419+
420+
if options is not None:
421+
arch_not_supported = options.arch is not None and options.arch not in supported_archs
422+
default_arch_not_supported = (
423+
options.arch is None
424+
and 10 * Device().compute_capability[0] + Device().compute_capability[1] not in supported_archs
425+
)
426+
427+
if arch_not_supported or default_arch_not_supported:
428+
raise ValueError(
429+
f"The provided arch, or default arch (that of the current device) "
430+
f"is not supported by the current backend. Supported architectures: {supported_archs}"
431+
)
416432
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
417433
self._backend = "nvrtc"
418434
else:
@@ -448,6 +464,12 @@ def compile(self, target_type, name_expressions=(), logs=None):
448464
raise NotImplementedError
449465

450466
if self._backend == "nvrtc":
467+
version = handle_return(nvrtc.nvrtcVersion())
468+
if handle_return(driver.cuDriverGetVersion()) > version[0] * 1000 + version[1] * 10:
469+
raise RuntimeError(
470+
"The CUDA driver version is newer than the NVRTC version. "
471+
"Please update your NVRTC library to match the CUDA driver version."
472+
)
451473
if name_expressions:
452474
for n in name_expressions:
453475
handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle)

0 commit comments

Comments
 (0)