Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 889124f

Browse files
committed
Move liveness tracking logic to LazyTensorOperationsTracker.swift.
1 parent 6b5c521 commit 889124f

File tree

2 files changed

+69
-48
lines changed

2 files changed

+69
-48
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -50,80 +50,40 @@ class LazyTensorHandle: _AnyTensorHandle {
5050
precondition(
5151
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5252
handle = Handle.symbolic(op, index: index, isLive: false)
53-
LazyTensorHandle.incrementRefCount(op, isLive: false)
53+
Self.operationsTracker.incrementRefCount(op, isLive: false)
5454
}
5555

5656
init(_lazyLive op: LazyTensorOperation, index: Int) {
5757
precondition(
5858
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
5959
handle = Handle.symbolic(op, index: index, isLive: true)
60-
LazyTensorHandle.incrementRefCount(op, isLive: true)
60+
Self.operationsTracker.incrementRefCount(op, isLive: true)
6161
}
6262

6363
deinit {
6464
if case let .symbolic(op, _, isLive) = handle {
65-
LazyTensorHandle.decrementRefCount(op, isLive: isLive)
65+
Self.operationsTracker.decrementRefCount(op, isLive: isLive)
6666
}
6767
}
68-
68+
6969
// Liveness tracking for LazyTensorOperations
7070
//
71-
struct LazyTensorOperationRefCounts {
72-
let op: LazyTensorOperation
73-
let liveRefCount: Int
74-
let allRefCount: Int
75-
}
76-
77-
private static var operationRefCounts: [
78-
ObjectIdentifier: LazyTensorOperationRefCounts] = [:]
79-
80-
static func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
81-
let opID = ObjectIdentifier(op)
82-
if let counts = operationRefCounts[opID] {
83-
operationRefCounts[opID] = LazyTensorOperationRefCounts(
84-
op: op,
85-
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
86-
allRefCount: counts.allRefCount + 1)
87-
} else {
88-
operationRefCounts[opID] = LazyTensorOperationRefCounts(
89-
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
90-
}
91-
}
92-
93-
static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
94-
let opID = ObjectIdentifier(op)
95-
if let counts = operationRefCounts[opID] {
96-
if counts.allRefCount > 1 {
97-
operationRefCounts[opID] = LazyTensorOperationRefCounts(
98-
op: op,
99-
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
100-
allRefCount: counts.allRefCount - 1)
101-
} else {
102-
operationRefCounts.removeValue(forKey: opID)
103-
}
104-
}
105-
}
71+
private static var operationsTracker = LazyTensorOperationsTracker()
10672

10773
static func isLive(_ op: LazyTensorOperation) -> Bool {
108-
let opID = ObjectIdentifier(op)
109-
if let counts = operationRefCounts[opID] {
110-
return counts.liveRefCount > 0
111-
}
112-
return false
74+
return operationsTracker.isLive(op)
11375
}
11476

11577
static func forEachLiveOperation(
11678
_ perform: (LazyTensorOperation) throws -> Void
11779
) rethrows -> Void {
118-
for (_, counts) in operationRefCounts where counts.liveRefCount > 0 {
119-
try perform(counts.op)
120-
}
80+
try operationsTracker.forEachLiveOperation(perform)
12181
}
12282

12383
static func forEachOperation(
12484
_ perform: (LazyTensorOperation) throws -> Void
12585
) rethrows -> Void {
126-
for (_, counts) in operationRefCounts { try perform(counts.op) }
86+
try operationsTracker.forEachOperation(perform)
12787
}
12888

12989
@usableFromInline
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/// A class to keep track of runtime information about `LazyTensorOperation`
2+
/// instances created by the program. This will be manaaged as a thread local
3+
/// state.
4+
class LazyTensorOperationsTracker {
5+
struct RefCounts {
6+
let op: LazyTensorOperation
7+
let liveRefCount: Int
8+
let allRefCount: Int
9+
}
10+
11+
private var refCounts: [ObjectIdentifier: RefCounts] = [:]
12+
13+
func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
14+
let opID = ObjectIdentifier(op)
15+
if let counts = refCounts[opID] {
16+
refCounts[opID] = RefCounts(
17+
op: op,
18+
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
19+
allRefCount: counts.allRefCount + 1)
20+
} else {
21+
refCounts[opID] = RefCounts(
22+
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
23+
}
24+
}
25+
26+
func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
27+
let opID = ObjectIdentifier(op)
28+
if let counts = refCounts[opID] {
29+
if counts.allRefCount > 1 {
30+
refCounts[opID] = RefCounts(
31+
op: op,
32+
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
33+
allRefCount: counts.allRefCount - 1)
34+
} else {
35+
refCounts.removeValue(forKey: opID)
36+
}
37+
}
38+
}
39+
40+
func isLive(_ op: LazyTensorOperation) -> Bool {
41+
let opID = ObjectIdentifier(op)
42+
if let counts = refCounts[opID] {
43+
return counts.liveRefCount > 0
44+
}
45+
return false
46+
}
47+
48+
func forEachLiveOperation(
49+
_ perform: (LazyTensorOperation) throws -> Void
50+
) rethrows -> Void {
51+
for (_, counts) in refCounts where counts.liveRefCount > 0 {
52+
try perform(counts.op)
53+
}
54+
}
55+
56+
func forEachOperation(
57+
_ perform: (LazyTensorOperation) throws -> Void
58+
) rethrows -> Void {
59+
for (_, counts) in refCounts { try perform(counts.op) }
60+
}
61+
}

0 commit comments

Comments
 (0)