@@ -964,19 +964,19 @@ def __new__(cls, device_id=None):
964
964
ctx = handle_return (driver .cuCtxGetCurrent ())
965
965
assert int (ctx ) == 0
966
966
device_id = 0 # cudart behavior
967
- assert isinstance (device_id , int ), f"{ device_id = } "
968
967
else :
969
968
total = handle_return (driver .cuDeviceGetCount ())
970
969
if not isinstance (device_id , int ) or not (0 <= device_id < total ):
971
970
raise ValueError (f"device_id must be within [0, { total } ), got { device_id } " )
972
971
973
972
# ensure Device is singleton
974
- if not hasattr (_tls , "devices" ):
973
+ try :
974
+ devices = _tls .devices
975
+ except AttributeError :
975
976
total = handle_return (driver .cuDeviceGetCount ())
976
- _tls .devices = []
977
+ devices = _tls .devices = []
977
978
for dev_id in range (total ):
978
979
dev = super ().__new__ (cls )
979
-
980
980
dev ._id = dev_id
981
981
# If the device is in TCC mode, or does not support memory pools for some other reason,
982
982
# use the SynchronousMemoryResource which does not use memory pools.
@@ -990,12 +990,12 @@ def __new__(cls, device_id=None):
990
990
dev ._mr = _DefaultAsyncMempool (dev_id )
991
991
else :
992
992
dev ._mr = _SynchronousMemoryResource (dev_id )
993
+
993
994
dev ._has_inited = False
994
995
dev ._properties = None
996
+ devices .append (dev )
995
997
996
- _tls .devices .append (dev )
997
-
998
- return _tls .devices [device_id ]
998
+ return devices [device_id ]
999
999
1000
1000
def _check_context_initialized (self , * args , ** kwargs ):
1001
1001
if not self ._has_inited :
0 commit comments