Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 13e0cd6

Browse files
committed
Make Liveness tracking a thread local state.
1 parent 889124f commit 13e0cd6

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,40 +50,38 @@ class LazyTensorHandle: _AnyTensorHandle {
5050
precondition(
5151
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5252
handle = Handle.symbolic(op, index: index, isLive: false)
53-
Self.operationsTracker.incrementRefCount(op, isLive: false)
53+
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: false)
5454
}
5555

5656
init(_lazyLive op: LazyTensorOperation, index: Int) {
5757
precondition(
5858
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5959
handle = Handle.symbolic(op, index: index, isLive: true)
60-
Self.operationsTracker.incrementRefCount(op, isLive: true)
60+
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: true)
6161
}
6262

6363
deinit {
6464
if case let .symbolic(op, _, isLive) = handle {
65-
Self.operationsTracker.decrementRefCount(op, isLive: isLive)
65+
LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
6666
}
6767
}
6868

6969
// Liveness tracking for LazyTensorOperations
7070
//
71-
private static var operationsTracker = LazyTensorOperationsTracker()
72-
7371
static func isLive(_ op: LazyTensorOperation) -> Bool {
74-
return operationsTracker.isLive(op)
72+
return LazyTensorContext.operationsTracker.isLive(op)
7573
}
7674

7775
static func forEachLiveOperation(
7876
_ perform: (LazyTensorOperation) throws -> Void
7977
) rethrows -> Void {
80-
try operationsTracker.forEachLiveOperation(perform)
78+
try LazyTensorContext.operationsTracker.forEachLiveOperation(perform)
8179
}
8280

8381
static func forEachOperation(
8482
_ perform: (LazyTensorOperation) throws -> Void
8583
) rethrows -> Void {
86-
try operationsTracker.forEachOperation(perform)
84+
try LazyTensorContext.operationsTracker.forEachOperation(perform)
8785
}
8886

8987
@usableFromInline

Sources/TensorFlow/Core/LazyTensorOperationsTracker.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,15 @@ class LazyTensorOperationsTracker {
5959
for (_, counts) in refCounts { try perform(counts.op) }
6060
}
6161
}
62+
63+
struct LazyTensorContext {
64+
var scopes = [LazyTensorOperationsTracker()]
65+
66+
static private var threadLocalContext: LazyTensorContext {
67+
_ThreadLocalState.local.lazyTensorContext
68+
}
69+
70+
static var operationsTracker: LazyTensorOperationsTracker {
71+
return threadLocalContext.scopes.last!
72+
}
73+
}

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,8 @@ fileprivate func setAttrShapeList(
12371237
class _ThreadLocalState {
12381238
var deviceScopes = DeviceScopes()
12391239

1240+
var lazyTensorContext = LazyTensorContext()
1241+
12401242
private static let key: pthread_key_t = {
12411243
var key = pthread_key_t()
12421244
pthread_key_create(&key) {

0 commit comments

Comments
 (0)