@@ -50,80 +50,40 @@ class LazyTensorHandle: _AnyTensorHandle {
50
50
precondition (
51
51
index < op. outputCount, " Symbolic Tensor Index is out-of-bounds " )
52
52
handle = Handle . symbolic ( op, index: index, isLive: false )
53
- LazyTensorHandle . incrementRefCount ( op, isLive: false )
53
+ Self . operationsTracker . incrementRefCount ( op, isLive: false )
54
54
}
55
55
56
56
init ( _lazyLive op: LazyTensorOperation , index: Int ) {
57
57
precondition (
58
58
index < op. outputCount, " Symbolic Tensor Index is out-of-bounds " )
59
59
handle = Handle . symbolic ( op, index: index, isLive: true )
60
- LazyTensorHandle . incrementRefCount ( op, isLive: true )
60
+ Self . operationsTracker . incrementRefCount ( op, isLive: true )
61
61
}
62
62
63
63
deinit {
64
64
if case let . symbolic( op, _, isLive) = handle {
65
- LazyTensorHandle . decrementRefCount ( op, isLive: isLive)
65
+ Self . operationsTracker . decrementRefCount ( op, isLive: isLive)
66
66
}
67
67
}
68
-
68
+
69
69
// Liveness tracking for LazyTensorOperations
70
70
//
71
- struct LazyTensorOperationRefCounts {
72
- let op : LazyTensorOperation
73
- let liveRefCount : Int
74
- let allRefCount : Int
75
- }
76
-
77
- private static var operationRefCounts : [
78
- ObjectIdentifier : LazyTensorOperationRefCounts ] = [ : ]
79
-
80
- static func incrementRefCount( _ op: LazyTensorOperation , isLive: Bool ) {
81
- let opID = ObjectIdentifier ( op)
82
- if let counts = operationRefCounts [ opID] {
83
- operationRefCounts [ opID] = LazyTensorOperationRefCounts (
84
- op: op,
85
- liveRefCount: counts. liveRefCount + ( isLive ? 1 : 0 ) ,
86
- allRefCount: counts. allRefCount + 1 )
87
- } else {
88
- operationRefCounts [ opID] = LazyTensorOperationRefCounts (
89
- op: op, liveRefCount: isLive ? 1 : 0 , allRefCount: 1 )
90
- }
91
- }
92
-
93
- static func decrementRefCount( _ op: LazyTensorOperation , isLive: Bool ) {
94
- let opID = ObjectIdentifier ( op)
95
- if let counts = operationRefCounts [ opID] {
96
- if counts. allRefCount > 1 {
97
- operationRefCounts [ opID] = LazyTensorOperationRefCounts (
98
- op: op,
99
- liveRefCount: counts. liveRefCount - ( isLive ? 1 : 0 ) ,
100
- allRefCount: counts. allRefCount - 1 )
101
- } else {
102
- operationRefCounts. removeValue ( forKey: opID)
103
- }
104
- }
105
- }
71
+ private static var operationsTracker = LazyTensorOperationsTracker ( )
106
72
107
73
static func isLive( _ op: LazyTensorOperation ) -> Bool {
108
- let opID = ObjectIdentifier ( op)
109
- if let counts = operationRefCounts [ opID] {
110
- return counts. liveRefCount > 0
111
- }
112
- return false
74
+ return operationsTracker. isLive ( op)
113
75
}
114
76
115
77
static func forEachLiveOperation(
116
78
_ perform: ( LazyTensorOperation ) throws -> Void
117
79
) rethrows -> Void {
118
- for (_, counts) in operationRefCounts where counts. liveRefCount > 0 {
119
- try perform ( counts. op)
120
- }
80
+ try operationsTracker. forEachLiveOperation ( perform)
121
81
}
122
82
123
83
static func forEachOperation(
124
84
_ perform: ( LazyTensorOperation ) throws -> Void
125
85
) rethrows -> Void {
126
- for (_ , counts ) in operationRefCounts { try perform ( counts . op ) }
86
+ try operationsTracker . forEachOperation ( perform )
127
87
}
128
88
129
89
@usableFromInline
0 commit comments