Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f9c63e4

Browse files
committed
Liveness tracking in lazy tensor
1 parent dbeeb06 commit f9c63e4

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

Sources/TensorFlow/LazyTensor/LazyTensorOperation.swift

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,83 @@ class LazyTensor: _AnyTensorHandle {
3636
precondition(
3737
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
3838
handle = Handle.symbolic(op, index: index, isLive: false)
39+
LazyTensor.incrementRefCount(op, isLive: false)
3940
}
4041

4142
init(_lazyLive op: LazyTensorOperation, index: Int) {
4243
precondition(
4344
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
4445
handle = Handle.symbolic(op, index: index, isLive: true)
46+
LazyTensor.incrementRefCount(op, isLive: true)
47+
}
48+
49+
deinit {
50+
if case let .symbolic(op, _, isLive) = handle {
51+
LazyTensor.decrementRefCount(op, isLive: isLive)
52+
}
53+
}
54+
55+
// Liveness tracking for LazyTensorOperations
56+
//
57+
struct LazyTensorOperationRefCounts {
58+
let op: LazyTensorOperation
59+
let live: Int
60+
let all: Int
61+
}
62+
63+
private static var operationRefCounts: [
64+
ObjectIdentifier: LazyTensorOperationRefCounts] = [:]
65+
66+
static func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
67+
let opId = ObjectIdentifier(op)
68+
if let counts = operationRefCounts[opId] {
69+
operationRefCounts[opId] = LazyTensorOperationRefCounts(
70+
op: op,
71+
live: isLive ? counts.live + 1 : counts.live,
72+
all: counts.all + 1)
73+
} else {
74+
operationRefCounts[opId] = LazyTensorOperationRefCounts(
75+
op: op, live: isLive ? 1 : 0, all: 1)
76+
}
77+
}
78+
79+
static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
80+
let opId = ObjectIdentifier(op)
81+
if let counts = operationRefCounts[opId] {
82+
if counts.all > 1 {
83+
operationRefCounts[opId] = LazyTensorOperationRefCounts(
84+
op: op,
85+
live: isLive ? counts.live - 1 : counts.live,
86+
all: counts.all - 1)
87+
} else {
88+
operationRefCounts.removeValue(forKey: opId)
89+
}
90+
}
91+
}
92+
93+
static func isLive(_ op: LazyTensorOperation) -> Bool {
94+
let opId = ObjectIdentifier(op)
95+
if let counts = operationRefCounts[opId] {
96+
return counts.live > 0
97+
}
98+
return false
99+
}
100+
101+
static func onLiveOperations(_ perform: (LazyTensorOperation) -> ()) {
102+
for (_, counts) in operationRefCounts {
103+
if (counts.live > 0) { perform(counts.op) }
104+
}
105+
}
106+
107+
static func onAllOperations(_ perform: (LazyTensorOperation) -> ()) {
108+
for (_, counts) in operationRefCounts { perform(counts.op) }
109+
}
110+
111+
public static func printRefCounts() {
112+
let live = operationRefCounts.values.reduce(0, { (sum, element) in
113+
return sum + (element.live > 0 ? 1 : 0)
114+
})
115+
print("LazyTensorOperations: \(operationRefCounts.count) (\(live) live)")
45116
}
46117

47118
static var _materializationCallback: (String) -> () = { _ in }
@@ -85,19 +156,26 @@ class LazyTensorOperation: TensorOperation {
85156
}
86157
}
87158

159+
public static var liveOperations: Int = 0
160+
88161
init(_id id: String?, name: String, outputCount: Int) {
89162
self.name = name
90163
self.inputs = []
91164
self.attrs = [:]
92165
self.outputCount = outputCount
93166
self.outputs = nil
94167
self.id = id
168+
LazyTensorOperation.liveOperations += 1
95169
}
96170

97171
required convenience init(_ name: String, _ outputCount: Int) {
98172
self.init(_id: nil, name: name, outputCount: outputCount)
99173
}
100174

175+
deinit {
176+
LazyTensorOperation.liveOperations -= 1
177+
}
178+
101179
func evaluate() -> [LazyTensor] {
102180
return (0..<outputCount).map {
103181
LazyTensor(_lazyLive: self, index: $0)

0 commit comments

Comments
 (0)