Skip to content

Commit c9fac0b

Browse files
committed
minor perf opt: try-except + skip assert
1 parent 3985435 commit c9fac0b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -964,19 +964,19 @@ def __new__(cls, device_id=None):
964964
ctx = handle_return(driver.cuCtxGetCurrent())
965965
assert int(ctx) == 0
966966
device_id = 0 # cudart behavior
967-
assert isinstance(device_id, int), f"{device_id=}"
968967
else:
969968
total = handle_return(driver.cuDeviceGetCount())
970969
if not isinstance(device_id, int) or not (0 <= device_id < total):
971970
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")
972971

973972
# ensure Device is singleton
974-
if not hasattr(_tls, "devices"):
973+
try:
974+
devices = _tls.devices
975+
except AttributeError:
975976
total = handle_return(driver.cuDeviceGetCount())
976-
_tls.devices = []
977+
devices = _tls.devices = []
977978
for dev_id in range(total):
978979
dev = super().__new__(cls)
979-
980980
dev._id = dev_id
981981
# If the device is in TCC mode, or does not support memory pools for some other reason,
982982
# use the SynchronousMemoryResource which does not use memory pools.
@@ -990,12 +990,12 @@ def __new__(cls, device_id=None):
990990
dev._mr = _DefaultAsyncMempool(dev_id)
991991
else:
992992
dev._mr = _SynchronousMemoryResource(dev_id)
993+
993994
dev._has_inited = False
994995
dev._properties = None
996+
devices.append(dev)
995997

996-
_tls.devices.append(dev)
997-
998-
return _tls.devices[device_id]
998+
return devices[device_id]
999999

10001000
def _check_context_initialized(self, *args, **kwargs):
10011001
if not self._has_inited:

0 commit comments

Comments
 (0)