@@ -36,12 +36,72 @@ class LazyTensor: _AnyTensorHandle {
36
36
precondition (
37
37
index < op. outputCount, " Symbolic Tensor Index is out-of-bounds " )
38
38
handle = Handle . symbolic ( op, index: index, isLive: false )
39
+ LazyTensor . incrementRefCount ( op, isLive: false )
39
40
}
40
41
41
42
init ( _lazyLive op: LazyTensorOperation , index: Int ) {
42
43
precondition (
43
44
index < op. outputCount, " Symbolic Tensor Index is out-of-bounds " )
44
45
handle = Handle . symbolic ( op, index: index, isLive: true )
46
+ LazyTensor . incrementRefCount ( op, isLive: true )
47
+ }
48
+
49
+ deinit {
50
+ if case let . symbolic( op, _, isLive) = handle {
51
+ LazyTensor . decrementRefCount ( op, isLive: isLive)
52
+ }
53
+ }
54
+
55
+ // Liveness tracking for LazyTensorOperations
56
+ //
57
+ struct LazyTensorOperationRefCounts {
58
+ let op : LazyTensorOperation
59
+ let liveRefCount : Int
60
+ let allRefCount : Int
61
+ }
62
+
63
+ private static var operationRefCounts : [
64
+ ObjectIdentifier : LazyTensorOperationRefCounts ] = [ : ]
65
+
66
+ static func incrementRefCount( _ op: LazyTensorOperation , isLive: Bool ) {
67
+ let opID = ObjectIdentifier ( op)
68
+ if let counts = operationRefCounts [ opID] {
69
+ operationRefCounts [ opID] = LazyTensorOperationRefCounts (
70
+ op: op,
71
+ liveRefCount: counts. liveRefCount + ( isLive ? 1 : 0 ) ,
72
+ allRefCount: counts. allRefCount + 1 )
73
+ } else {
74
+ operationRefCounts [ opID] = LazyTensorOperationRefCounts (
75
+ op: op, liveRefCount: isLive ? 1 : 0 , allRefCount: 1 )
76
+ }
77
+ }
78
+
79
+ static func decrementRefCount( _ op: LazyTensorOperation , isLive: Bool ) {
80
+ let opID = ObjectIdentifier ( op)
81
+ if let counts = operationRefCounts [ opID] {
82
+ if counts. allRefCount > 1 {
83
+ operationRefCounts [ opID] = LazyTensorOperationRefCounts (
84
+ op: op,
85
+ liveRefCount: counts. liveRefCount - ( isLive ? 1 : 0 ) ,
86
+ allRefCount: counts. allRefCount - 1 )
87
+ } else {
88
+ operationRefCounts. removeValue ( forKey: opID)
89
+ }
90
+ }
91
+ }
92
+
93
+ static func isLive( _ op: LazyTensorOperation ) -> Bool {
94
+ let opID = ObjectIdentifier ( op)
95
+ if let counts = operationRefCounts [ opID] {
96
+ return counts. liveRefCount > 0
97
+ }
98
+ return false
99
+ }
100
+
101
+ static func onLiveOperations( _ perform: ( LazyTensorOperation ) -> ( ) ) {
102
+ for (_, counts) in operationRefCounts where counts. liveRefCount > 0 {
103
+ perform ( counts. op)
104
+ }
45
105
}
46
106
47
107
static var _materializationCallback : ( String ) -> ( ) = { _ in }
@@ -85,19 +145,26 @@ class LazyTensorOperation: TensorOperation {
85
145
}
86
146
}
87
147
148
+ static var liveOperations : Int = 0
149
+
88
150
init ( _id id: String ? , name: String , outputCount: Int ) {
89
151
self . name = name
90
152
self . inputs = [ ]
91
153
self . attrs = [ : ]
92
154
self . outputCount = outputCount
93
155
self . outputs = nil
94
156
self . id = id
157
+ LazyTensorOperation . liveOperations += 1
95
158
}
96
159
97
160
required convenience init ( _ name: String , _ outputCount: Int ) {
98
161
self . init ( _id: nil , name: name, outputCount: outputCount)
99
162
}
100
163
164
+ deinit {
165
+ LazyTensorOperation . liveOperations -= 1
166
+ }
167
+
101
168
func evaluate( ) -> [ LazyTensor ] {
102
169
return ( 0 ..< outputCount) . map {
103
170
LazyTensor ( _lazyLive: self , index: $0)
0 commit comments