@@ -20,55 +20,55 @@ extension TFETensorHandle {
20
20
}
21
21
22
22
/// Returns true if the underlying tensors are equal.
23
- static func areTensorsEqual ( _ lhs : TFETensorHandle , _ rhs : TFETensorHandle ) -> Bool {
24
- let lhsDtype = TFE_TensorHandleDataType ( lhs . _cTensorHandle)
25
- let rhsDtype = TFE_TensorHandleDataType ( rhs . _cTensorHandle)
23
+ func elementsEqual ( _ other : TFETensorHandle ) -> Bool {
24
+ let selfDtype = TFE_TensorHandleDataType ( self . _cTensorHandle)
25
+ let otherDtype = TFE_TensorHandleDataType ( other . _cTensorHandle)
26
26
precondition (
27
- lhsDtype == rhsDtype && lhsDtype != TF_VARIANT && lhsDtype != TF_RESOURCE,
27
+ selfDtype == otherDtype && selfDtype != TF_VARIANT && selfDtype != TF_RESOURCE,
28
28
" Datatypes of tensor handles don't match. " )
29
29
let op = TFE_Op ( " Equal " , 1 )
30
- op. updateAttribute ( " T " , TensorDataType ( lhsDtype ) )
31
- op. addInput ( lhs )
32
- op. addInput ( rhs )
30
+ op. updateAttribute ( " T " , TensorDataType ( selfDtype ) )
31
+ op. addInput ( self )
32
+ op. addInput ( other )
33
33
let result : Tensor < Bool > = op. execute ( Int ( 1 ) )
34
34
return result. scalars. allSatisfy { $0 }
35
35
}
36
36
}
37
37
38
38
extension LazyTensorHandle {
39
- static func areHandlesEquivalent ( _ lhs : LazyTensorHandle , _ rhs : LazyTensorHandle ) -> Bool {
40
- switch ( lhs . handle, rhs . handle) {
39
+ func isEquivalent ( to other : LazyTensorHandle ) -> Bool {
40
+ switch ( self . handle, other . handle) {
41
41
case let ( . concrete( x, _) , . concrete( y, _) ) :
42
42
return TFETensorHandle . areHandlesEquivalent ( x, y)
43
43
case let ( . symbolic( x, xi, _) , . symbolic( y, yi, _) ) :
44
- return ( xi == yi) && ( x. id == y. id)
44
+ return xi == yi && x. id == y. id
45
45
default : return false
46
46
}
47
47
}
48
48
}
49
49
50
- extension LazyTensorOperation {
50
+ extension LazyTensorOperation . Input {
51
51
/// Returns true if these inputs are equivalent when comparing lazy tensor traces.
52
- static func areInputsEquivalent ( _ lhs : Input , _ rhs : Input ) -> Bool {
53
- switch ( lhs , rhs ) {
52
+ func isEquivalent ( to other : LazyTensorOperation . Input ) -> Bool {
53
+ switch ( self , other ) {
54
54
case let ( . single( l) , . single( r) ) :
55
- return LazyTensorHandle . areHandlesEquivalent ( l , r)
55
+ return l . isEquivalent ( to : r)
56
56
case let ( . list( l) , . list( r) ) :
57
- return l. elementsEqual ( r, by: { LazyTensorHandle . areHandlesEquivalent ( $0 , $1) } )
57
+ return l. elementsEqual ( r, by: { $0 . isEquivalent ( to : $1) } )
58
58
default :
59
59
return false
60
60
}
61
61
}
62
+ }
62
63
64
+ extension LazyTensorOperation {
63
65
/// Returns true if these operations are equivalent when comparing lazy tensor traces.
64
- static func areEquivalent( _ lhs: LazyTensorOperation , _ rhs: LazyTensorOperation ) -> Bool {
65
- return ( lhs. name == rhs. name) &&
66
- ( lhs. outputCount == rhs. outputCount) &&
67
- ( lhs. deviceName == rhs. deviceName) &&
68
- lhs. inputs. elementsEqual (
69
- rhs. inputs,
70
- by: { LazyTensorOperation . areInputsEquivalent ( $0, $1) } ) &&
71
- ( lhs. attributes == rhs. attributes)
66
+ func isEquivalent( to other: LazyTensorOperation ) -> Bool {
67
+ return self . name == other. name &&
68
+ self . outputCount == other. outputCount &&
69
+ self . deviceName == other. deviceName &&
70
+ self . inputs. elementsEqual ( other. inputs, by: { $0. isEquivalent ( to: $1) } ) &&
71
+ self . attributes == other. attributes
72
72
}
73
73
}
74
74
@@ -100,21 +100,23 @@ func ==(_ lhs: LazyTensorOperation.Attribute, _ rhs: LazyTensorOperation.Attribu
100
100
}
101
101
}
102
102
103
- // TODO(https://bugs.swift.org/browse/ TF-693): This is not thread safe!
103
+ // TODO(TF-693): This is not thread safe!
104
104
struct LazyTensorTraceCache {
105
- // Cache from signature to traces that match signature.
105
+ /// Cache from signature to traces that match signature.
106
106
static private var cache : [ String : [ LazyTensorTrace ] ] = [ : ]
107
107
static func clearCache( ) { cache. removeAll ( ) }
108
108
109
- // Returns a `MaterializationTraceInfo` with possibly some constants promoted to inputs.
110
- static func traceWithPromotedConstants( _ traceInfo: MaterializationTraceInfo ) -> MaterializationTraceInfo {
109
+ /// Returns a `MaterializationTraceInfo` with possibly some constants promoted to inputs.
110
+ static func traceWithPromotedConstants(
111
+ _ traceInfo: MaterializationTraceInfo
112
+ ) -> MaterializationTraceInfo {
111
113
let trace = traceInfo. trace
112
114
guard var traces = cache [ trace. signature] else {
113
115
cache [ trace. signature] = [ trace]
114
116
return traceInfo
115
117
}
116
118
for cachedTrace in traces {
117
- if let promotedTrace = traceWithPromotedConstants ( traceInfo, cachedTrace) {
119
+ if let promotedTrace = traceInfo. withPromotedConstants ( cachedTrace : cachedTrace) {
118
120
debugLog ( " Promoted: \( promotedTrace) \n " )
119
121
return promotedTrace
120
122
}
@@ -123,23 +125,22 @@ struct LazyTensorTraceCache {
123
125
traces. append ( trace)
124
126
return traceInfo
125
127
}
128
+ }
126
129
127
- static private func traceWithPromotedConstants(
128
- _ traceInfo: MaterializationTraceInfo ,
129
- _ cachedTrace: LazyTensorTrace
130
- ) -> MaterializationTraceInfo ? {
131
- let currentTrace = traceInfo. trace
130
+ private extension MaterializationTraceInfo {
131
+ func withPromotedConstants( cachedTrace: LazyTensorTrace ) -> MaterializationTraceInfo ? {
132
+ let currentTrace = self . trace
132
133
if currentTrace. operations. count != cachedTrace. operations. count { return nil }
133
134
var promotableConstants : [ ( Int , TFETensorHandle ) ] = [ ]
134
135
for (i, current) in currentTrace. operations. enumerated ( ) {
135
136
let cached = cachedTrace. operations [ i]
136
- if let ( currentTensor, cachedTensor) = promotableConstant ( current, cached) {
137
- if TFETensorHandle . areTensorsEqual ( currentTensor , cachedTensor) { continue }
137
+ if let ( currentTensor, cachedTensor) = Self . promotableConstants ( current, cached) {
138
+ if currentTensor . elementsEqual ( cachedTensor) { continue }
138
139
promotableConstants. append ( ( i, currentTensor) )
139
140
continue
140
141
}
141
142
// TODO: we might avoid running the following check based on results of promotableConstant
142
- if LazyTensorOperation . areEquivalent ( current , cached) { continue }
143
+ if current . isEquivalent ( to : cached) { continue }
143
144
return nil
144
145
}
145
146
@@ -157,26 +158,27 @@ struct LazyTensorTraceCache {
157
158
operations: newOperations,
158
159
outputs: currentTrace. outputs)
159
160
return MaterializationTraceInfo (
160
- lazyOperations: traceInfo . lazyOperations,
161
+ lazyOperations: self . lazyOperations,
161
162
trace: newTrace,
162
- concreteInputs: traceInfo . concreteInputs + newConcreteInputs)
163
+ concreteInputs: self . concreteInputs + newConcreteInputs)
163
164
}
164
165
165
166
/// If `current` and `cached` are compatible constants, returns the constant tensors.
166
- static private func promotableConstant (
167
+ static private func promotableConstants (
167
168
_ current: LazyTensorOperation ,
168
169
_ cached: LazyTensorOperation
169
170
) -> ( TFETensorHandle , TFETensorHandle ) ? {
170
- if ( current. name != " Const " || cached. name != " Const " ) { return nil }
171
+ if current. name != " Const " || cached. name != " Const " { return nil }
171
172
let currentValue = current. attributes [ " value " ] !
172
173
let cachedValue = cached. attributes [ " value " ] !
173
- guard case let . constTensor( currentTensor) = currentValue else { return nil }
174
- guard case let . constTensor( cachedTensor) = cachedValue else { return nil }
174
+ guard case let . constTensor( currentTensor) = currentValue,
175
+ case let . constTensor( cachedTensor) = cachedValue
176
+ else { return nil }
175
177
let currentDtype = TFE_TensorHandleDataType ( currentTensor. _cTensorHandle)
176
178
let cachedDtype = TFE_TensorHandleDataType ( cachedTensor. _cTensorHandle)
177
179
if currentDtype == TF_VARIANT || currentDtype == TF_RESOURCE { return nil }
178
180
if cachedDtype == TF_VARIANT || cachedDtype == TF_RESOURCE { return nil }
179
- return ( currentTensor. shape == cachedTensor. shape) && ( currentDtype == cachedDtype)
181
+ return currentTensor. shape == cachedTensor. shape && currentDtype == cachedDtype
180
182
? ( currentTensor, cachedTensor)
181
183
: nil
182
184
}
0 commit comments