Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Refactor DeviceScopes out into a separate structure. #370

Merged
merged 2 commits into from
Jul 17, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions Sources/TensorFlow/Core/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ internal extension _ExecutionContext {
/// withDevice call on the call stack or the presence of an immediately enclosing
/// `withDefaultDevice(perform)` call.
var currentDeviceName: String? {
return _ThreadLocalState.local._currentDevice
return _ThreadLocalState.local.deviceScopes._currentDevice
}

/// See documentation for the top-level `withDevice(_:_:perform)`.
Expand Down Expand Up @@ -1029,17 +1029,17 @@ internal extension _ExecutionContext {
guard deviceNames.contains(name) else {
fatalError("Device \(name) not found")
}
_ThreadLocalState.local.pushDevice(name)
_ThreadLocalState.local.deviceScopes.pushDevice(name)
let result = try body()
_ThreadLocalState.local.popDevice()
_ThreadLocalState.local.deviceScopes.popDevice()
return result
}

/// See documentation for the top-level `withDefaultDevice(perform)`.
func withDefaultDevice<R>(perform body: () throws -> R) rethrows -> R {
_ThreadLocalState.local.pushDevice(nil)
_ThreadLocalState.local.deviceScopes.pushDevice(nil)
let result = try body()
_ThreadLocalState.local.popDevice()
_ThreadLocalState.local.deviceScopes.popDevice()
return result
}
}
Expand Down Expand Up @@ -1233,16 +1233,9 @@ fileprivate func setAttrShapeList(
}
}

/// Stack of devices that models nested calls to withDevice/withDefaultDevice. Devices are
/// represented by their names in TensorFlow notation. See documentation for
/// `withDevice(named:perform:)` to learn about device names.
///
/// All TensorFlow operations will be put on the topmost device on the stack. When the stack is
/// empty or the topmost device is `nil`, that allows TensorFlow to place operations on any device
/// that it sees fit.
@usableFromInline
/// A class to keep around thread local state.
class _ThreadLocalState {
var deviceScopes: [String?] = []
var deviceScopes = DeviceScopes()

private static let key: pthread_key_t = {
var key = pthread_key_t()
Expand All @@ -1256,20 +1249,6 @@ class _ThreadLocalState {
return key
}()

var _currentDevice: String? {
return deviceScopes.last ?? nil
}

@usableFromInline
func pushDevice(_ device: String?) {
deviceScopes.append(device)
}

@usableFromInline
func popDevice() {
internalConsistencyCheck(deviceScopes.popLast() != nil)
}

@usableFromInline
static var local: _ThreadLocalState {
if let state = pthread_getspecific(key) {
Expand All @@ -1281,6 +1260,32 @@ class _ThreadLocalState {
}
}

/// Stack of devices that models nested calls to withDevice/withDefaultDevice. Devices are
/// represented by their names in TensorFlow notation. See documentation for
/// `withDevice(named:perform:)` to learn about device names.
///
/// All TensorFlow operations will be put on the topmost device on the stack. When the stack is
/// empty or the topmost device is `nil`, that allows TensorFlow to place operations on any device
/// that it sees fit.
@usableFromInline
struct DeviceScopes {
var deviceStack: [String?] = []

var _currentDevice: String? {
return deviceStack.last ?? nil
}

@usableFromInline
mutating func pushDevice(_ device: String?) {
deviceStack.append(device)
}

@usableFromInline
mutating func popDevice() {
internalConsistencyCheck(deviceStack.popLast() != nil)
}
}

@usableFromInline
@_cdecl("_swift_tfc_OpSetDeviceFromScope")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be deleted now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still used in EagerExecution.swift. However, those use cases can be rewritten. Will send out a separate PR for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you were only referring to _cdecl. Will send out a PR cleaning this up and a few other uses.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it'd be nice to have EagerExecution.swift use C API directly.

func _TFCOpSetDeviceFromScope(_ op: CTFEOp, _ status: CTFStatus) {
Expand Down