Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d538aab

Browse files
authored
Add liveness tracking to LazyTensor (#190)
1 parent 4c0a09c commit d538aab

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

Sources/TensorFlow/LazyTensor/LazyTensorOperation.swift

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,72 @@ class LazyTensor: _AnyTensorHandle {
3636
precondition(
3737
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
3838
handle = Handle.symbolic(op, index: index, isLive: false)
39+
LazyTensor.incrementRefCount(op, isLive: false)
3940
}
4041

4142
init(_lazyLive op: LazyTensorOperation, index: Int) {
4243
precondition(
4344
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
4445
handle = Handle.symbolic(op, index: index, isLive: true)
46+
LazyTensor.incrementRefCount(op, isLive: true)
47+
}
48+
49+
deinit {
50+
if case let .symbolic(op, _, isLive) = handle {
51+
LazyTensor.decrementRefCount(op, isLive: isLive)
52+
}
53+
}
54+
55+
// Liveness tracking for LazyTensorOperations
56+
//
57+
struct LazyTensorOperationRefCounts {
58+
let op: LazyTensorOperation
59+
let liveRefCount: Int
60+
let allRefCount: Int
61+
}
62+
63+
private static var operationRefCounts: [
64+
ObjectIdentifier: LazyTensorOperationRefCounts] = [:]
65+
66+
static func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
67+
let opID = ObjectIdentifier(op)
68+
if let counts = operationRefCounts[opID] {
69+
operationRefCounts[opID] = LazyTensorOperationRefCounts(
70+
op: op,
71+
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
72+
allRefCount: counts.allRefCount + 1)
73+
} else {
74+
operationRefCounts[opID] = LazyTensorOperationRefCounts(
75+
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
76+
}
77+
}
78+
79+
static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
80+
let opID = ObjectIdentifier(op)
81+
if let counts = operationRefCounts[opID] {
82+
if counts.allRefCount > 1 {
83+
operationRefCounts[opID] = LazyTensorOperationRefCounts(
84+
op: op,
85+
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
86+
allRefCount: counts.allRefCount - 1)
87+
} else {
88+
operationRefCounts.removeValue(forKey: opID)
89+
}
90+
}
91+
}
92+
93+
static func isLive(_ op: LazyTensorOperation) -> Bool {
94+
let opID = ObjectIdentifier(op)
95+
if let counts = operationRefCounts[opID] {
96+
return counts.liveRefCount > 0
97+
}
98+
return false
99+
}
100+
101+
static func onLiveOperations(_ perform: (LazyTensorOperation) -> ()) {
102+
for (_, counts) in operationRefCounts where counts.liveRefCount > 0 {
103+
perform(counts.op)
104+
}
45105
}
46106

47107
static var _materializationCallback: (String) -> () = { _ in }
@@ -85,19 +145,26 @@ class LazyTensorOperation: TensorOperation {
85145
}
86146
}
87147

148+
static var liveOperations: Int = 0
149+
88150
init(_id id: String?, name: String, outputCount: Int) {
89151
self.name = name
90152
self.inputs = []
91153
self.attrs = [:]
92154
self.outputCount = outputCount
93155
self.outputs = nil
94156
self.id = id
157+
LazyTensorOperation.liveOperations += 1
95158
}
96159

97160
required convenience init(_ name: String, _ outputCount: Int) {
98161
self.init(_id: nil, name: name, outputCount: outputCount)
99162
}
100163

164+
deinit {
165+
LazyTensorOperation.liveOperations -= 1
166+
}
167+
101168
func evaluate() -> [LazyTensor] {
102169
return (0..<outputCount).map {
103170
LazyTensor(_lazyLive: self, index: $0)

Tests/TensorFlowTests/LazyTensorTests.swift

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ import XCTest
1616
@testable import TensorFlow
1717
import CTensorFlow
1818

19+
struct LazyTensorOperationRef: Equatable, Hashable {
20+
let value: LazyTensorOperation
21+
init(_ value: LazyTensorOperation) { self.value = value }
22+
func hash(into hasher: inout Hasher) {
23+
hasher.combine(ObjectIdentifier(value))
24+
}
25+
}
26+
27+
func ==(lhs: LazyTensorOperationRef, rhs: LazyTensorOperationRef) -> Bool {
28+
return lhs.value === rhs.value
29+
}
30+
1931
final class LazyTensorTests: XCTestCase {
2032
func testConstructions() {
2133
let zero = Tensor<Float>(0.0)
@@ -40,8 +52,57 @@ final class LazyTensorTests: XCTestCase {
4052
XCTAssertEqual("\(liveSymTensor)", "IdentityN_0:0*")
4153
}
4254

55+
func testLivenessTracking() {
56+
func assertLive(_ expectedLive: [LazyTensorOperation]) {
57+
var actualLiveOps: Set<LazyTensorOperationRef> = []
58+
LazyTensor.onLiveOperations {
59+
actualLiveOps.insert(LazyTensorOperationRef($0))
60+
}
61+
let expectedLiveOps = Set<LazyTensorOperationRef>(
62+
expectedLive.map { LazyTensorOperationRef($0) }
63+
)
64+
XCTAssertEqual(expectedLiveOps, actualLiveOps)
65+
}
66+
67+
func isSymbolic(_ t: LazyTensor) -> Bool {
68+
if case let .symbolic(_) = t.handle {
69+
return true
70+
} else {
71+
return false
72+
}
73+
}
74+
75+
let op0 = LazyTensorOperation(
76+
_id: "0", name: "IdentityN", outputCount: 2)
77+
let op1 = LazyTensorOperation(
78+
_id: "1", name: "IdentityN", outputCount: 2)
79+
80+
XCTAssertFalse(LazyTensor.isLive(op0))
81+
XCTAssertFalse(LazyTensor.isLive(op1))
82+
83+
let t0 = LazyTensor(_lazyLive: op0, index: 0)
84+
let t1 = LazyTensor(_lazy: op1, index: 1)
85+
XCTAssertTrue(LazyTensor.isLive(op0))
86+
XCTAssertFalse(LazyTensor.isLive(op1))
87+
88+
do {
89+
let t3 = LazyTensor(_lazyLive: op1, index: 0)
90+
XCTAssertTrue(LazyTensor.isLive(op1))
91+
assertLive([op0, op1])
92+
// The following is here just to ensure t3 is live.
93+
XCTAssertTrue(isSymbolic(t3))
94+
}
95+
XCTAssertFalse(LazyTensor.isLive(op1))
96+
assertLive([op0])
97+
98+
// The following are here just to ensure t0 and t1 are live.
99+
XCTAssertTrue(isSymbolic(t1))
100+
XCTAssertTrue(isSymbolic(t0))
101+
}
102+
43103
static var allTests = [
44104
("testConstructions", testConstructions),
105+
("testLivenessTracking", testLivenessTracking),
45106
]
46107

47108
}

0 commit comments

Comments
 (0)