Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f7eefd7

Browse files
committed
Add liveness tracking tests.
1 parent f9c63e4 commit f7eefd7

File tree

2 files changed

+70
-16
lines changed

2 files changed

+70
-16
lines changed

Sources/TensorFlow/LazyTensor/LazyTensorOperation.swift

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class LazyTensor: _AnyTensorHandle {
5656
//
5757
struct LazyTensorOperationRefCounts {
5858
let op: LazyTensorOperation
59-
let live: Int
60-
let all: Int
59+
let liveRefCount: Int
60+
let allRefCount: Int
6161
}
6262

6363
private static var operationRefCounts: [
@@ -68,22 +68,22 @@ class LazyTensor: _AnyTensorHandle {
6868
if let counts = operationRefCounts[opId] {
6969
operationRefCounts[opId] = LazyTensorOperationRefCounts(
7070
op: op,
71-
live: isLive ? counts.live + 1 : counts.live,
72-
all: counts.all + 1)
71+
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
72+
allRefCount: counts.allRefCount + 1)
7373
} else {
7474
operationRefCounts[opId] = LazyTensorOperationRefCounts(
75-
op: op, live: isLive ? 1 : 0, all: 1)
75+
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
7676
}
7777
}
7878

7979
static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
8080
let opId = ObjectIdentifier(op)
8181
if let counts = operationRefCounts[opId] {
82-
if counts.all > 1 {
82+
if counts.allRefCount > 1 {
8383
operationRefCounts[opId] = LazyTensorOperationRefCounts(
8484
op: op,
85-
live: isLive ? counts.live - 1 : counts.live,
86-
all: counts.all - 1)
85+
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
86+
allRefCount: counts.allRefCount - 1)
8787
} else {
8888
operationRefCounts.removeValue(forKey: opId)
8989
}
@@ -93,24 +93,20 @@ class LazyTensor: _AnyTensorHandle {
9393
static func isLive(_ op: LazyTensorOperation) -> Bool {
9494
let opId = ObjectIdentifier(op)
9595
if let counts = operationRefCounts[opId] {
96-
return counts.live > 0
96+
return counts.liveRefCount > 0
9797
}
9898
return false
9999
}
100100

101101
static func onLiveOperations(_ perform: (LazyTensorOperation) -> ()) {
102102
for (_, counts) in operationRefCounts {
103-
if (counts.live > 0) { perform(counts.op) }
103+
if (counts.liveRefCount > 0) { perform(counts.op) }
104104
}
105105
}
106106

107-
static func onAllOperations(_ perform: (LazyTensorOperation) -> ()) {
108-
for (_, counts) in operationRefCounts { perform(counts.op) }
109-
}
110-
111-
public static func printRefCounts() {
107+
static func printRefCounts() {
112108
let live = operationRefCounts.values.reduce(0, { (sum, element) in
113-
return sum + (element.live > 0 ? 1 : 0)
109+
return sum + (element.liveRefCount > 0 ? 1 : 0)
114110
})
115111
print("LazyTensorOperations: \(operationRefCounts.count) (\(live) live)")
116112
}

Tests/TensorFlowTests/LazyTensorTests.swift

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ import XCTest
1616
@testable import TensorFlow
1717
import CTensorFlow
1818

19+
struct LazyTensorOperationRef : Equatable, Hashable {
20+
let value: LazyTensorOperation
21+
init(_ value: LazyTensorOperation) { self.value = value }
22+
func hash(into hasher: inout Hasher) {
23+
hasher.combine(ObjectIdentifier(value))
24+
}
25+
}
26+
27+
func ==(lhs: LazyTensorOperationRef, rhs: LazyTensorOperationRef) -> Bool {
28+
return lhs.value === rhs.value
29+
}
30+
1931
final class LazyTensorTests: XCTestCase {
2032
func testConstructions() {
2133
let zero = Tensor<Float>(0.0)
@@ -40,8 +52,54 @@ final class LazyTensorTests: XCTestCase {
4052
XCTAssertEqual("\(liveSymTensor)", "IdentityN_0:0*")
4153
}
4254

55+
func testLivenessTracking() {
56+
func assertLive(_ expectedLive: [LazyTensorOperation]) {
57+
var actualLiveOps: Set<LazyTensorOperationRef> = []
58+
LazyTensor.onLiveOperations {
59+
actualLiveOps.insert(LazyTensorOperationRef($0))
60+
}
61+
let expectedLiveOps = Set<LazyTensorOperationRef>(
62+
expectedLive.map { LazyTensorOperationRef($0) }
63+
)
64+
XCTAssertEqual(expectedLiveOps, actualLiveOps)
65+
}
66+
func isSymbolic(_ t: LazyTensor) -> Bool {
67+
if case let .symbolic(_) = t.handle {
68+
return true
69+
} else {
70+
return false
71+
}
72+
}
73+
74+
let op0 = LazyTensorOperation(
75+
_id: "0", name: "IdentityN", outputCount: 2)
76+
let op1 = LazyTensorOperation(
77+
_id: "1", name: "IdentityN", outputCount: 2)
78+
79+
XCTAssertFalse(LazyTensor.isLive(op0))
80+
XCTAssertFalse(LazyTensor.isLive(op1))
81+
82+
let t0 = LazyTensor(_lazyLive: op0, index: 0)
83+
let t1 = LazyTensor(_lazy: op1, index: 1)
84+
XCTAssertTrue(LazyTensor.isLive(op0))
85+
XCTAssertFalse(LazyTensor.isLive(op1))
86+
do {
87+
let t3 = LazyTensor(_lazyLive: op1, index: 0)
88+
XCTAssertTrue(LazyTensor.isLive(op1))
89+
assertLive([op0, op1])
90+
// The following are here just to ensure t3 is live.
91+
XCTAssertTrue(isSymbolic(t3))
92+
}
93+
XCTAssertFalse(LazyTensor.isLive(op1))
94+
assertLive([op0])
95+
// The following are here just to ensure t0 and t1 are live.
96+
XCTAssertTrue(isSymbolic(t1))
97+
XCTAssertTrue(isSymbolic(t0))
98+
}
99+
43100
static var allTests = [
44101
("testConstructions", testConstructions),
102+
("testLivenessTracking", testLivenessTracking),
45103
]
46104

47105
}

0 commit comments

Comments
 (0)