Commit 440eabd

Improve perf of accessing dev.compute_capability (#459)
* cache cc to speed it up
* avoid silly, redundant lock
1 parent 976246a commit 440eabd
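
As a rough illustration of the access being optimized, repeated reads of the attribute could be timed with a micro-benchmark along these lines (a hypothetical sketch, not part of this commit; it assumes a CUDA-capable GPU and an installed cuda.core):

# Hypothetical micro-benchmark for this change; not part of the commit.
# Assumes a CUDA-capable GPU and that cuda.core is installed.
import timeit
from cuda.core.experimental import Device

dev = Device()
dev.set_current()

# Before this commit every access issued two cudaDeviceGetAttribute calls;
# after it, repeated accesses are served from a Python-level cache.
elapsed = timeit.timeit(lambda: dev.compute_capability, number=100_000)
print(f"100k reads of dev.compute_capability took {elapsed:.3f}s")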

2 files changed: +34 −31 lines changed
cuda_core/cuda/core/experimental/_device.py

Lines changed: 33 additions & 29 deletions
@@ -11,7 +11,8 @@
 from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime

 _tls = threading.local()
-_tls_lock = threading.Lock()
+_lock = threading.Lock()
+_is_cuInit = False


 class DeviceProperties:
@@ -938,6 +939,12 @@ class Device:
     __slots__ = ("_id", "_mr", "_has_inited", "_properties")

     def __new__(cls, device_id=None):
+        global _is_cuInit
+        if _is_cuInit is False:
+            with _lock:
+                handle_return(driver.cuInit(0))
+                _is_cuInit = True
+
         # important: creating a Device instance does not initialize the GPU!
         if device_id is None:
             device_id = handle_return(runtime.cudaGetDevice())
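
The added block is a check-then-lock form of one-time lazy initialization: the unlocked flag read keeps the hot path cheap, and because cuInit is idempotent, two threads racing past the check merely repeat a harmless call. A minimal standalone sketch of the same idiom, with a hypothetical _do_init standing in for driver.cuInit:

import threading

_lock = threading.Lock()
_initialized = False

def _do_init():
    # Stand-in for an idempotent call like driver.cuInit(0).
    print("init ran")

def _ensure_init():
    global _initialized
    if _initialized is False:      # fast path: no lock once the flag is set
        with _lock:
            _do_init()             # safe to repeat on a rare race
            _initialized = True

_ensure_init()  # prints "init ran"
_ensure_init()  # no output, no lock acquired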
@@ -948,27 +955,26 @@ def __new__(cls, device_id=None):
             raise ValueError(f"device_id must be within [0, {total}), got {device_id}")

         # ensure Device is singleton
-        with _tls_lock:
-            if not hasattr(_tls, "devices"):
-                total = handle_return(runtime.cudaGetDeviceCount())
-                _tls.devices = []
-                for dev_id in range(total):
-                    dev = super().__new__(cls)
-                    dev._id = dev_id
-                    # If the device is in TCC mode, or does not support memory pools for some other reason,
-                    # use the SynchronousMemoryResource which does not use memory pools.
-                    if (
-                        handle_return(
-                            runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
-                        )
-                    ) == 1:
-                        dev._mr = _DefaultAsyncMempool(dev_id)
-                    else:
-                        dev._mr = _SynchronousMemoryResource(dev_id)
-
-                    dev._has_inited = False
-                    dev._properties = None
-                    _tls.devices.append(dev)
+        if not hasattr(_tls, "devices"):
+            total = handle_return(runtime.cudaGetDeviceCount())
+            _tls.devices = []
+            for dev_id in range(total):
+                dev = super().__new__(cls)
+                dev._id = dev_id
+                # If the device is in TCC mode, or does not support memory pools for some other reason,
+                # use the SynchronousMemoryResource which does not use memory pools.
+                if (
+                    handle_return(
+                        runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
+                    )
+                ) == 1:
+                    dev._mr = _DefaultAsyncMempool(dev_id)
+                else:
+                    dev._mr = _SynchronousMemoryResource(dev_id)
+
+                dev._has_inited = False
+                dev._properties = None
+                _tls.devices.append(dev)

         return _tls.devices[device_id]
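
Dropping the lock around this block is the "silly, redundant lock" fix: _tls is a threading.local, so each thread builds and sees its own devices list and no other thread can observe it mid-construction. A small self-contained sketch of that isolation:

import threading

_tls = threading.local()

def build_cache():
    # The hasattr check and assignment touch only this thread's _tls slot,
    # so no lock is needed even with many threads running build_cache.
    if not hasattr(_tls, "devices"):
        _tls.devices = [f"dev-{threading.current_thread().name}"]
    return _tls.devices

threads = [threading.Thread(target=build_cache, name=str(i)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(build_cache())  # the main thread gets its own, separate list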

@@ -1029,13 +1035,11 @@ def properties(self) -> DeviceProperties:
     @property
     def compute_capability(self) -> ComputeCapability:
         """Return a named tuple with 2 fields: major and minor."""
-        major = handle_return(
-            runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, self._id)
-        )
-        minor = handle_return(
-            runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, self._id)
-        )
-        return ComputeCapability(major, minor)
+        if "compute_capability" in self.properties._cache:
+            return self.properties._cache["compute_capability"]
+        cc = ComputeCapability(self.properties.compute_capability_major, self.properties.compute_capability_minor)
+        self.properties._cache["compute_capability"] = cc
+        return cc

     @property
     @precondition(_check_context_initialized)
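
The rewritten property routes both reads through DeviceProperties and then memoizes the assembled tuple in its _cache dict, so only the first access touches the CUDA runtime. The same compute-once, store-in-a-dict pattern in isolation (a sketch with made-up names, not the library's API):

# Generic sketch of the dict-backed memoization used above; the names
# here are illustrative, not part of cuda.core.
class Props:
    def __init__(self):
        self._cache = {}

    def _expensive_query(self, name):
        print(f"querying {name}")  # stands in for cudaDeviceGetAttribute
        return 9 if name == "major" else 0

    @property
    def compute_capability(self):
        if "compute_capability" in self._cache:
            return self._cache["compute_capability"]
        cc = (self._expensive_query("major"), self._expensive_query("minor"))
        self._cache["compute_capability"] = cc
        return cc

p = Props()
p.compute_capability  # prints twice: queries major and minor
p.compute_capability  # silent: served from the cache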

cuda_core/tests/conftest.py

Lines changed: 1 addition & 2 deletions
@@ -41,8 +41,7 @@ def _device_unset_current():
         return
     handle_return(driver.cuCtxPopCurrent())
     if hasattr(_device._tls, "devices"):
-        with _device._tls_lock:
-            del _device._tls.devices
+        del _device._tls.devices


 @pytest.fixture(scope="function")
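
The same thread-local reasoning applies to this test teardown: it deletes only the calling thread's cached devices, so the lock bought nothing. For context, a hypothetical sketch of how a helper like _device_unset_current is typically wired into a function-scoped fixture (the actual fixture body is outside this diff):

import pytest

def _device_unset_current():
    ...  # pops the CUDA context and deletes this thread's _tls.devices, as above

@pytest.fixture(scope="function")
def init_cuda():
    yield
    _device_unset_current()  # teardown runs after each test function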
