@@ -1001,7 +1001,7 @@ internal extension _ExecutionContext {
1001
1001
/// withDevice call on the call stack or the presence of an immediately enclosing
1002
1002
/// `withDefaultDevice(perform)` call.
1003
1003
var currentDeviceName : String ? {
1004
- return _ThreadLocalState. local. _currentDevice
1004
+ return _ThreadLocalState. local. deviceScopes . _currentDevice
1005
1005
}
1006
1006
1007
1007
/// See documentation for the top-level `withDevice(_:_:perform)`.
@@ -1029,17 +1029,17 @@ internal extension _ExecutionContext {
1029
1029
guard deviceNames. contains ( name) else {
1030
1030
fatalError ( " Device \( name) not found " )
1031
1031
}
1032
- _ThreadLocalState. local. pushDevice ( name)
1032
+ _ThreadLocalState. local. deviceScopes . pushDevice ( name)
1033
1033
let result = try body ( )
1034
- _ThreadLocalState. local. popDevice ( )
1034
+ _ThreadLocalState. local. deviceScopes . popDevice ( )
1035
1035
return result
1036
1036
}
1037
1037
1038
1038
/// See documentation for the top-level `withDefaultDevice(perform)`.
1039
1039
func withDefaultDevice< R> ( perform body: ( ) throws -> R ) rethrows -> R {
1040
- _ThreadLocalState. local. pushDevice ( nil )
1040
+ _ThreadLocalState. local. deviceScopes . pushDevice ( nil )
1041
1041
let result = try body ( )
1042
- _ThreadLocalState. local. popDevice ( )
1042
+ _ThreadLocalState. local. deviceScopes . popDevice ( )
1043
1043
return result
1044
1044
}
1045
1045
}
@@ -1233,16 +1233,9 @@ fileprivate func setAttrShapeList(
1233
1233
}
1234
1234
}
1235
1235
1236
- /// Stack of devices that models nested calls to withDevice/withDefaultDevice. Devices are
1237
- /// represented by their names in TensorFlow notation. See documentation for
1238
- /// `withDevice(named:perform:)` to learn about device names.
1239
- ///
1240
- /// All TensorFlow operations will be put on the topmost device on the stack. When the stack is
1241
- /// empty or the topmost device is `nil`, that allows TensorFlow to place operations on any device
1242
- /// that it sees fit.
1243
- @usableFromInline
1236
+ /// A class to keep around thread local state.
1244
1237
class _ThreadLocalState {
1245
- var deviceScopes : [ String ? ] = [ ]
1238
+ var deviceScopes = DeviceScopes ( )
1246
1239
1247
1240
private static let key : pthread_key_t = {
1248
1241
var key = pthread_key_t ( )
@@ -1256,20 +1249,6 @@ class _ThreadLocalState {
1256
1249
return key
1257
1250
} ( )
1258
1251
1259
- var _currentDevice : String ? {
1260
- return deviceScopes. last ?? nil
1261
- }
1262
-
1263
- @usableFromInline
1264
- func pushDevice( _ device: String ? ) {
1265
- deviceScopes. append ( device)
1266
- }
1267
-
1268
- @usableFromInline
1269
- func popDevice( ) {
1270
- internalConsistencyCheck ( deviceScopes. popLast ( ) != nil )
1271
- }
1272
-
1273
1252
@usableFromInline
1274
1253
static var local : _ThreadLocalState {
1275
1254
if let state = pthread_getspecific ( key) {
@@ -1281,6 +1260,32 @@ class _ThreadLocalState {
1281
1260
}
1282
1261
}
1283
1262
1263
+ /// Stack of devices that models nested calls to withDevice/withDefaultDevice. Devices are
1264
+ /// represented by their names in TensorFlow notation. See documentation for
1265
+ /// `withDevice(named:perform:)` to learn about device names.
1266
+ ///
1267
+ /// All TensorFlow operations will be put on the topmost device on the stack. When the stack is
1268
+ /// empty or the topmost device is `nil`, that allows TensorFlow to place operations on any device
1269
+ /// that it sees fit.
1270
+ @usableFromInline
1271
+ struct DeviceScopes {
1272
+ var deviceStack : [ String ? ] = [ ]
1273
+
1274
+ var _currentDevice : String ? {
1275
+ return deviceStack. last ?? nil
1276
+ }
1277
+
1278
+ @usableFromInline
1279
+ mutating func pushDevice( _ device: String ? ) {
1280
+ deviceStack. append ( device)
1281
+ }
1282
+
1283
+ @usableFromInline
1284
+ mutating func popDevice( ) {
1285
+ internalConsistencyCheck ( deviceStack. popLast ( ) != nil )
1286
+ }
1287
+ }
1288
+
1284
1289
@usableFromInline
1285
1290
@_cdecl ( " _swift_tfc_OpSetDeviceFromScope " )
1286
1291
func _TFCOpSetDeviceFromScope( _ op: CTFEOp , _ status: CTFStatus ) {
0 commit comments