Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6b5c521

Browse files
authored
Refactor DeviceScopes out into a separate structure. (#370)
* Refactor DeviceScopes out into a separate structure. * class -> struct
1 parent 5e24305 commit 6b5c521

File tree

1 file changed

+33
-28
lines changed

1 file changed

+33
-28
lines changed

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ internal extension _ExecutionContext {
10011001
/// withDevice call on the call stack or the presence of an immediately enclosing
10021002
/// `withDefaultDevice(perform)` call.
10031003
var currentDeviceName: String? {
1004-
return _ThreadLocalState.local._currentDevice
1004+
return _ThreadLocalState.local.deviceScopes._currentDevice
10051005
}
10061006

10071007
/// See documentation for the top-level `withDevice(_:_:perform)`.
@@ -1029,17 +1029,17 @@ internal extension _ExecutionContext {
10291029
guard deviceNames.contains(name) else {
10301030
fatalError("Device \(name) not found")
10311031
}
1032-
_ThreadLocalState.local.pushDevice(name)
1032+
_ThreadLocalState.local.deviceScopes.pushDevice(name)
10331033
let result = try body()
1034-
_ThreadLocalState.local.popDevice()
1034+
_ThreadLocalState.local.deviceScopes.popDevice()
10351035
return result
10361036
}
10371037

10381038
/// See documentation for the top-level `withDefaultDevice(perform)`.
10391039
func withDefaultDevice<R>(perform body: () throws -> R) rethrows -> R {
1040-
_ThreadLocalState.local.pushDevice(nil)
1040+
_ThreadLocalState.local.deviceScopes.pushDevice(nil)
10411041
let result = try body()
1042-
_ThreadLocalState.local.popDevice()
1042+
_ThreadLocalState.local.deviceScopes.popDevice()
10431043
return result
10441044
}
10451045
}
@@ -1233,16 +1233,9 @@ fileprivate func setAttrShapeList(
12331233
}
12341234
}
12351235

1236-
/// Stack of devices that models nested calls to withDevice/withDefaultDevice. Devices are
1237-
/// represented by their names in TensorFlow notation. See documentation for
1238-
/// `withDevice(named:perform:)` to learn about device names.
1239-
///
1240-
/// All TensorFlow operations will be put on the topmost device on the stack. When the stack is
1241-
/// empty or the topmost device is `nil`, that allows TensorFlow to place operations on any device
1242-
/// that it sees fit.
1243-
@usableFromInline
1236+
/// A class to keep around thread local state.
12441237
class _ThreadLocalState {
1245-
var deviceScopes: [String?] = []
1238+
var deviceScopes = DeviceScopes()
12461239

12471240
private static let key: pthread_key_t = {
12481241
var key = pthread_key_t()
@@ -1256,20 +1249,6 @@ class _ThreadLocalState {
12561249
return key
12571250
}()
12581251

1259-
var _currentDevice: String? {
1260-
return deviceScopes.last ?? nil
1261-
}
1262-
1263-
@usableFromInline
1264-
func pushDevice(_ device: String?) {
1265-
deviceScopes.append(device)
1266-
}
1267-
1268-
@usableFromInline
1269-
func popDevice() {
1270-
internalConsistencyCheck(deviceScopes.popLast() != nil)
1271-
}
1272-
12731252
@usableFromInline
12741253
static var local: _ThreadLocalState {
12751254
if let state = pthread_getspecific(key) {
@@ -1281,6 +1260,32 @@ class _ThreadLocalState {
12811260
}
12821261
}
12831262

1263+
/// Stack of devices that models nested calls to withDevice/withDefaultDevice. Devices are
1264+
/// represented by their names in TensorFlow notation. See documentation for
1265+
/// `withDevice(named:perform:)` to learn about device names.
1266+
///
1267+
/// All TensorFlow operations will be put on the topmost device on the stack. When the stack is
1268+
/// empty or the topmost device is `nil`, that allows TensorFlow to place operations on any device
1269+
/// that it sees fit.
1270+
@usableFromInline
1271+
struct DeviceScopes {
1272+
var deviceStack: [String?] = []
1273+
1274+
var _currentDevice: String? {
1275+
return deviceStack.last ?? nil
1276+
}
1277+
1278+
@usableFromInline
1279+
mutating func pushDevice(_ device: String?) {
1280+
deviceStack.append(device)
1281+
}
1282+
1283+
@usableFromInline
1284+
mutating func popDevice() {
1285+
internalConsistencyCheck(deviceStack.popLast() != nil)
1286+
}
1287+
}
1288+
12841289
@usableFromInline
12851290
@_cdecl("_swift_tfc_OpSetDeviceFromScope")
12861291
func _TFCOpSetDeviceFromScope(_ op: CTFEOp, _ status: CTFStatus) {

0 commit comments

Comments
 (0)