Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 97ca18b

Browse files
authored
Make liveness tracking of LazyTensorOperations a thread local state. (#372)
* Move liveness tracking logic to LazyTensorOperationsTracker.swift. * Make Liveness tracking a thread local state. * Renaming LazyTensorOperationsTracker -> LazyTensorContext * Update comment on Runtime.swift. * fix a typo.
1 parent 3fdaa61 commit 97ca18b

File tree

3 files changed

+85
-50
lines changed

3 files changed

+85
-50
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/// A class to keep track of runtime information about `LazyTensorOperation`
2+
/// instances created by the program. This will be managed as a thread local
3+
/// state.
4+
class LazyTensorOperationsTracker {
5+
struct RefCounts {
6+
let op: LazyTensorOperation
7+
let liveRefCount: Int
8+
let allRefCount: Int
9+
}
10+
11+
private var refCounts: [ObjectIdentifier: RefCounts] = [:]
12+
13+
func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
14+
let opID = ObjectIdentifier(op)
15+
if let counts = refCounts[opID] {
16+
refCounts[opID] = RefCounts(
17+
op: op,
18+
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
19+
allRefCount: counts.allRefCount + 1)
20+
} else {
21+
refCounts[opID] = RefCounts(
22+
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
23+
}
24+
}
25+
26+
func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
27+
let opID = ObjectIdentifier(op)
28+
if let counts = refCounts[opID] {
29+
if counts.allRefCount > 1 {
30+
refCounts[opID] = RefCounts(
31+
op: op,
32+
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
33+
allRefCount: counts.allRefCount - 1)
34+
} else {
35+
refCounts.removeValue(forKey: opID)
36+
}
37+
}
38+
}
39+
40+
func isLive(_ op: LazyTensorOperation) -> Bool {
41+
let opID = ObjectIdentifier(op)
42+
if let counts = refCounts[opID] {
43+
return counts.liveRefCount > 0
44+
}
45+
return false
46+
}
47+
48+
func forEachLiveOperation(
49+
_ perform: (LazyTensorOperation) throws -> Void
50+
) rethrows -> Void {
51+
for (_, counts) in refCounts where counts.liveRefCount > 0 {
52+
try perform(counts.op)
53+
}
54+
}
55+
56+
func forEachOperation(
57+
_ perform: (LazyTensorOperation) throws -> Void
58+
) rethrows -> Void {
59+
for (_, counts) in refCounts { try perform(counts.op) }
60+
}
61+
}
62+
63+
struct LazyTensorContext {
64+
var scopes = [LazyTensorOperationsTracker()]
65+
66+
static private var threadLocalContext: LazyTensorContext {
67+
_ThreadLocalState.local.lazyTensorContext
68+
}
69+
70+
static var operationsTracker: LazyTensorOperationsTracker {
71+
return threadLocalContext.scopes.last!
72+
}
73+
}

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -50,80 +50,38 @@ class LazyTensorHandle: _AnyTensorHandle {
5050
precondition(
5151
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5252
handle = Handle.symbolic(op, index: index, isLive: false)
53-
LazyTensorHandle.incrementRefCount(op, isLive: false)
53+
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: false)
5454
}
5555

5656
init(_lazyLive op: LazyTensorOperation, index: Int) {
5757
precondition(
5858
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5959
handle = Handle.symbolic(op, index: index, isLive: true)
60-
LazyTensorHandle.incrementRefCount(op, isLive: true)
60+
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: true)
6161
}
6262

6363
deinit {
6464
if case let .symbolic(op, _, isLive) = handle {
65-
LazyTensorHandle.decrementRefCount(op, isLive: isLive)
65+
LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
6666
}
6767
}
68-
68+
6969
// Liveness tracking for LazyTensorOperations
7070
//
71-
struct LazyTensorOperationRefCounts {
72-
let op: LazyTensorOperation
73-
let liveRefCount: Int
74-
let allRefCount: Int
75-
}
76-
77-
private static var operationRefCounts: [
78-
ObjectIdentifier: LazyTensorOperationRefCounts] = [:]
79-
80-
static func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
81-
let opID = ObjectIdentifier(op)
82-
if let counts = operationRefCounts[opID] {
83-
operationRefCounts[opID] = LazyTensorOperationRefCounts(
84-
op: op,
85-
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
86-
allRefCount: counts.allRefCount + 1)
87-
} else {
88-
operationRefCounts[opID] = LazyTensorOperationRefCounts(
89-
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
90-
}
91-
}
92-
93-
static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
94-
let opID = ObjectIdentifier(op)
95-
if let counts = operationRefCounts[opID] {
96-
if counts.allRefCount > 1 {
97-
operationRefCounts[opID] = LazyTensorOperationRefCounts(
98-
op: op,
99-
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
100-
allRefCount: counts.allRefCount - 1)
101-
} else {
102-
operationRefCounts.removeValue(forKey: opID)
103-
}
104-
}
105-
}
106-
10771
static func isLive(_ op: LazyTensorOperation) -> Bool {
108-
let opID = ObjectIdentifier(op)
109-
if let counts = operationRefCounts[opID] {
110-
return counts.liveRefCount > 0
111-
}
112-
return false
72+
return LazyTensorContext.operationsTracker.isLive(op)
11373
}
11474

11575
static func forEachLiveOperation(
11676
_ perform: (LazyTensorOperation) throws -> Void
11777
) rethrows -> Void {
118-
for (_, counts) in operationRefCounts where counts.liveRefCount > 0 {
119-
try perform(counts.op)
120-
}
78+
try LazyTensorContext.operationsTracker.forEachLiveOperation(perform)
12179
}
12280

12381
static func forEachOperation(
12482
_ perform: (LazyTensorOperation) throws -> Void
12583
) rethrows -> Void {
126-
for (_, counts) in operationRefCounts { try perform(counts.op) }
84+
try LazyTensorContext.operationsTracker.forEachOperation(perform)
12785
}
12886

12987
@usableFromInline

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,10 +1193,14 @@ func _TFCOpSetAttrTypeArray(
11931193
}
11941194
}
11951195

1196-
/// A class to keep around thread local state.
1196+
/// A class to keep around thread local state:
1197+
/// - DeviceScopes
1198+
/// - LazyTensorContext
11971199
class _ThreadLocalState {
11981200
var deviceScopes = DeviceScopes()
11991201

1202+
var lazyTensorContext = LazyTensorContext()
1203+
12001204
private static let key: pthread_key_t = {
12011205
var key = pthread_key_t()
12021206
pthread_key_create(&key) {

0 commit comments

Comments
 (0)