Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit faf540a

Browse files
bgogulrxwei
authored andcommitted
Cleanup: use .local for accessing thread local state. (#416)
1 parent 22f5e43 commit faf540a

File tree

3 files changed

+14
-25
lines changed

3 files changed

+14
-25
lines changed

Sources/TensorFlow/Core/LazyTensorContext.swift

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,11 @@ class LazyTensorOperationsTracker {
6161
}
6262

6363
struct LazyTensorContext {
64-
private var operationsTracker = LazyTensorOperationsTracker()
65-
private var isShapeTrackingEnabled = true
64+
var operationsTracker = LazyTensorOperationsTracker()
65+
var isShapeTrackingEnabled = true
6666

67-
static private var local: LazyTensorContext {
67+
static var local: LazyTensorContext {
6868
_read { yield _ThreadLocalState.local.lazyTensorContext }
6969
_modify { yield &_ThreadLocalState.local.lazyTensorContext }
7070
}
71-
72-
static var operationsTracker: LazyTensorOperationsTracker {
73-
return local.operationsTracker
74-
}
75-
76-
/// A flag that determines whether we should track shapes. We will need to disable shape
77-
/// tracking within certain contexts. e.g., we won't be able to compute shapes when tracing.
78-
static var isShapeTrackingEnabled: Bool {
79-
get { local.isShapeTrackingEnabled }
80-
set { local.isShapeTrackingEnabled = newValue }
81-
}
8271
}

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,19 @@ class LazyTensorHandle: _AnyTensorHandle {
5050
precondition(
5151
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5252
handle = Handle.symbolic(op, index: index, isLive: false)
53-
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: false)
53+
LazyTensorContext.local.operationsTracker.incrementRefCount(op, isLive: false)
5454
}
5555

5656
init(_lazyLive op: LazyTensorOperation, index: Int) {
5757
precondition(
5858
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5959
handle = Handle.symbolic(op, index: index, isLive: true)
60-
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: true)
60+
LazyTensorContext.local.operationsTracker.incrementRefCount(op, isLive: true)
6161
}
6262

6363
deinit {
6464
if case let .symbolic(op, _, isLive) = handle {
65-
LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
65+
LazyTensorContext.local.operationsTracker.decrementRefCount(op, isLive: isLive)
6666
}
6767
}
6868

@@ -80,7 +80,7 @@ class LazyTensorHandle: _AnyTensorHandle {
8080
get {
8181
switch handle {
8282
case .symbolic(let op, let index, _):
83-
precondition(LazyTensorContext.isShapeTrackingEnabled,
83+
precondition(LazyTensorContext.local.isShapeTrackingEnabled,
8484
"Shape tracking is not enabled in this context.")
8585
if let shape = op.outputShapes[index] { return shape }
8686
// Materialize and get the shape from concrete tensor handle.
@@ -102,19 +102,19 @@ class LazyTensorHandle: _AnyTensorHandle {
102102
// Liveness tracking for LazyTensorOperations
103103
//
104104
static func isLive(_ op: LazyTensorOperation) -> Bool {
105-
return LazyTensorContext.operationsTracker.isLive(op)
105+
return LazyTensorContext.local.operationsTracker.isLive(op)
106106
}
107107

108108
static func forEachLiveOperation(
109109
_ perform: (LazyTensorOperation) throws -> Void
110110
) rethrows -> Void {
111-
try LazyTensorContext.operationsTracker.forEachLiveOperation(perform)
111+
try LazyTensorContext.local.operationsTracker.forEachLiveOperation(perform)
112112
}
113113

114114
static func forEachOperation(
115115
_ perform: (LazyTensorOperation) throws -> Void
116116
) rethrows -> Void {
117-
try LazyTensorContext.operationsTracker.forEachOperation(perform)
117+
try LazyTensorContext.local.operationsTracker.forEachOperation(perform)
118118
}
119119

120120
@usableFromInline
@@ -263,7 +263,7 @@ class LazyTensorOperation: TensorOperation {
263263
}
264264

265265
func evaluate() -> [LazyTensorHandle] {
266-
if LazyTensorContext.isShapeTrackingEnabled {
266+
if LazyTensorContext.local.isShapeTrackingEnabled {
267267
updateOutputShapes()
268268
}
269269
return (0..<outputCount).map {

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ class LazyTensorTraceBuilder {
9393
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
9494
precondition(_ThreadLocalState.useLazyTensor, "Lazy tensor is not enabled for tracing.")
9595
// Disable shape tracking and reset to original state when done.
96-
let isShapeTrackingEnabled = LazyTensorContext.isShapeTrackingEnabled
97-
defer { LazyTensorContext.isShapeTrackingEnabled = isShapeTrackingEnabled }
98-
LazyTensorContext.isShapeTrackingEnabled = false
96+
let isShapeTrackingEnabled = LazyTensorContext.local.isShapeTrackingEnabled
97+
defer { LazyTensorContext.local.isShapeTrackingEnabled = isShapeTrackingEnabled }
98+
LazyTensorContext.local.isShapeTrackingEnabled = false
9999

100100
// Set up inputs for running `fn`.
101101
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }

0 commit comments

Comments
 (0)