Skip to content

Switch to use CUDA driver APIs in Device constructor #460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into the base branch from the pull-request branch
Jun 7, 2025
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions cuda_core/cuda/core/experimental/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,34 +957,42 @@ def __new__(cls, device_id=None):

# important: creating a Device instance does not initialize the GPU!
if device_id is None:
device_id = handle_return(runtime.cudaGetDevice())
assert_type(device_id, int)
err, dev = driver.cuCtxGetDevice()
if err == 0:
device_id = int(dev)
else:
ctx = handle_return(driver.cuCtxGetCurrent())
assert int(ctx) == 0
device_id = 0 # cudart behavior
assert isinstance(device_id, int), f"{device_id=}"
else:
total = handle_return(runtime.cudaGetDeviceCount())
assert_type(device_id, int)
if not (0 <= device_id < total):
total = handle_return(driver.cuDeviceGetCount())
if not isinstance(device_id, int) or not (0 <= device_id < total):
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")

# ensure Device is singleton
if not hasattr(_tls, "devices"):
total = handle_return(runtime.cudaGetDeviceCount())
total = handle_return(driver.cuDeviceGetCount())
_tls.devices = []
for dev_id in range(total):
dev = super().__new__(cls)

dev._id = dev_id
# If the device is in TCC mode, or does not support memory pools for some other reason,
# use the SynchronousMemoryResource which does not use memory pools.
if (
handle_return(
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
)
)
) == 1:
dev._mr = _DefaultAsyncMempool(dev_id)
else:
dev._mr = _SynchronousMemoryResource(dev_id)

dev._has_inited = False
dev._properties = None

_tls.devices.append(dev)

return _tls.devices[device_id]
Expand Down
Loading