@@ -36,12 +36,83 @@ class LazyTensor: _AnyTensorHandle {
36
36
precondition (
37
37
index < op. outputCount, " Symbolic Tensor Index is out-of-bounds " )
38
38
handle = Handle . symbolic ( op, index: index, isLive: false )
39
+ LazyTensor . incrementRefCount ( op, isLive: false )
39
40
}
40
41
41
42
init ( _lazyLive op: LazyTensorOperation , index: Int ) {
42
43
precondition (
43
44
index < op. outputCount, " Symbolic Tensor Index is out-of-bounds " )
44
45
handle = Handle . symbolic ( op, index: index, isLive: true )
46
+ LazyTensor . incrementRefCount ( op, isLive: true )
47
+ }
48
+
49
+ deinit {
50
+ if case let . symbolic( op, _, isLive) = handle {
51
+ LazyTensor . decrementRefCount ( op, isLive: isLive)
52
+ }
53
+ }
54
+
55
+ // Liveness tracking for LazyTensorOperations
56
+ //
57
+ struct LazyTensorOperationRefCounts {
58
+ let op : LazyTensorOperation
59
+ let live : Int
60
+ let all : Int
61
+ }
62
+
63
+ private static var operationRefCounts : [
64
+ ObjectIdentifier : LazyTensorOperationRefCounts ] = [ : ]
65
+
66
+ static func incrementRefCount( _ op: LazyTensorOperation , isLive: Bool ) {
67
+ let opId = ObjectIdentifier ( op)
68
+ if let counts = operationRefCounts [ opId] {
69
+ operationRefCounts [ opId] = LazyTensorOperationRefCounts (
70
+ op: op,
71
+ live: isLive ? counts. live + 1 : counts. live,
72
+ all: counts. all + 1 )
73
+ } else {
74
+ operationRefCounts [ opId] = LazyTensorOperationRefCounts (
75
+ op: op, live: isLive ? 1 : 0 , all: 1 )
76
+ }
77
+ }
78
+
79
+ static func decrementRefCount( _ op: LazyTensorOperation , isLive: Bool ) {
80
+ let opId = ObjectIdentifier ( op)
81
+ if let counts = operationRefCounts [ opId] {
82
+ if counts. all > 1 {
83
+ operationRefCounts [ opId] = LazyTensorOperationRefCounts (
84
+ op: op,
85
+ live: isLive ? counts. live - 1 : counts. live,
86
+ all: counts. all - 1 )
87
+ } else {
88
+ operationRefCounts. removeValue ( forKey: opId)
89
+ }
90
+ }
91
+ }
92
+
93
+ static func isLive( _ op: LazyTensorOperation ) -> Bool {
94
+ let opId = ObjectIdentifier ( op)
95
+ if let counts = operationRefCounts [ opId] {
96
+ return counts. live > 0
97
+ }
98
+ return false
99
+ }
100
+
101
+ static func onLiveOperations( _ perform: ( LazyTensorOperation ) -> ( ) ) {
102
+ for (_, counts) in operationRefCounts {
103
+ if ( counts. live > 0 ) { perform ( counts. op) }
104
+ }
105
+ }
106
+
107
+ static func onAllOperations( _ perform: ( LazyTensorOperation ) -> ( ) ) {
108
+ for (_, counts) in operationRefCounts { perform ( counts. op) }
109
+ }
110
+
111
+ public static func printRefCounts( ) {
112
+ let live = operationRefCounts. values. reduce ( 0 , { ( sum, element) in
113
+ return sum + ( element. live > 0 ? 1 : 0 )
114
+ } )
115
+ print ( " LazyTensorOperations: \( operationRefCounts. count) ( \( live) live) " )
45
116
}
46
117
47
118
static var _materializationCallback : ( String ) -> ( ) = { _ in }
@@ -85,19 +156,26 @@ class LazyTensorOperation: TensorOperation {
85
156
}
86
157
}
87
158
159
+ public static var liveOperations : Int = 0
160
+
88
161
init ( _id id: String ? , name: String , outputCount: Int ) {
89
162
self . name = name
90
163
self . inputs = [ ]
91
164
self . attrs = [ : ]
92
165
self . outputCount = outputCount
93
166
self . outputs = nil
94
167
self . id = id
168
+ LazyTensorOperation . liveOperations += 1
95
169
}
96
170
97
171
required convenience init ( _ name: String , _ outputCount: Int ) {
98
172
self . init ( _id: nil , name: name, outputCount: outputCount)
99
173
}
100
174
175
+ deinit {
176
+ LazyTensorOperation . liveOperations -= 1
177
+ }
178
+
101
179
func evaluate( ) -> [ LazyTensor ] {
102
180
return ( 0 ..< outputCount) . map {
103
181
LazyTensor ( _lazyLive: self , index: $0)
0 commit comments