@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol {
32
32
33
33
var _tensorHandleCount : Int32 { get }
34
34
var _typeList : [ TensorDataType ] { get }
35
+ var _tensorHandles : [ _AnyTensorHandle ] { get }
35
36
36
37
init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? , count: Int )
38
+ init < C: RandomAccessCollection > ( _handles: C ) where C. Element == _AnyTensorHandle
37
39
}
38
40
39
41
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
@@ -88,13 +90,21 @@ extension TensorHandle: TensorGroup {
88
90
return [ Scalar . tensorFlowDataType]
89
91
}
90
92
93
+ public var _tensorHandles : [ _AnyTensorHandle ] { [ self . handle] }
94
+
91
95
public func _unpackTensorHandles( into address: UnsafeMutablePointer < CTensorHandle > ? ) {
92
96
address!. initialize ( to: _cTensorHandle)
93
97
}
94
98
95
99
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
96
100
self . init ( _owning: tensorHandles!. pointee)
97
101
}
102
+
103
+ public init < C: RandomAccessCollection > (
104
+ _handles: C ) where C. Element == _AnyTensorHandle {
105
+ precondition ( _handles. count == 1 )
106
+ self . init ( handle: _handles [ _handles. startIndex] )
107
+ }
98
108
}
99
109
100
110
extension ResourceHandle : TensorGroup {
@@ -108,13 +118,21 @@ extension ResourceHandle: TensorGroup {
108
118
return [ TensorDataType ( TF_RESOURCE) ]
109
119
}
110
120
121
+ public var _tensorHandles : [ _AnyTensorHandle ] { [ self . handle] }
122
+
111
123
public func _unpackTensorHandles( into address: UnsafeMutablePointer < CTensorHandle > ? ) {
112
124
address!. initialize ( to: _cTensorHandle)
113
125
}
114
126
115
127
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
116
128
self . init ( owning: tensorHandles!. pointee)
117
129
}
130
+
131
+ public init < C: RandomAccessCollection > (
132
+ _handles: C ) where C. Element == _AnyTensorHandle {
133
+ precondition ( _handles. count == 1 )
134
+ self . init ( handle: _handles [ _handles. startIndex] )
135
+ }
118
136
}
119
137
120
138
extension VariantHandle : TensorGroup {
@@ -128,13 +146,21 @@ extension VariantHandle: TensorGroup {
128
146
return [ TensorDataType ( TF_VARIANT) ]
129
147
}
130
148
149
+ public var _tensorHandles : [ _AnyTensorHandle ] { [ self . handle] }
150
+
131
151
public func _unpackTensorHandles( into address: UnsafeMutablePointer < CTensorHandle > ? ) {
132
152
address!. initialize ( to: _cTensorHandle)
133
153
}
134
154
135
155
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
136
156
self . init ( owning: tensorHandles!. pointee)
137
157
}
158
+
159
+ public init < C: RandomAccessCollection > (
160
+ _handles: C ) where C. Element == _AnyTensorHandle {
161
+ precondition ( _handles. count == 1 )
162
+ self . init ( handle: _handles [ _handles. startIndex] )
163
+ }
138
164
}
139
165
140
166
extension Tensor : TensorGroup {
@@ -152,9 +178,17 @@ extension Tensor: TensorGroup {
152
178
address!. initialize ( to: handle. _cTensorHandle)
153
179
}
154
180
181
+ public var _tensorHandles : [ _AnyTensorHandle ] { [ self . handle. handle] }
182
+
155
183
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
156
184
self . init ( handle: TensorHandle ( _owning: tensorHandles!. pointee) )
157
185
}
186
+
187
+ public init < C: RandomAccessCollection > (
188
+ _handles: C ) where C. Element == _AnyTensorHandle {
189
+ precondition ( _handles. count == 1 )
190
+ self . init ( handle: TensorHandle ( handle: _handles [ _handles. startIndex] ) )
191
+ }
158
192
}
159
193
160
194
extension _TensorElementLiteral : TensorGroup {
@@ -168,13 +202,21 @@ extension _TensorElementLiteral: TensorGroup {
168
202
return [ Scalar . tensorFlowDataType]
169
203
}
170
204
205
+ public var _tensorHandles : [ _AnyTensorHandle ] { [ self . handle. handle] }
206
+
171
207
public func _unpackTensorHandles( into address: UnsafeMutablePointer < CTensorHandle > ? ) {
172
208
address!. initialize ( to: handle. _cTensorHandle)
173
209
}
174
210
175
211
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
176
212
self . init ( handle: TensorHandle ( _owning: tensorHandles!. pointee) )
177
213
}
214
+
215
+ public init < C: RandomAccessCollection > (
216
+ _handles: C ) where C. Element == _AnyTensorHandle {
217
+ precondition ( _handles. count == 1 )
218
+ self . init ( handle: TensorHandle ( handle: _handles [ _handles. startIndex] ) )
219
+ }
178
220
}
179
221
180
222
extension StringTensor : TensorGroup {
@@ -192,9 +234,17 @@ extension StringTensor: TensorGroup {
192
234
address!. initialize ( to: handle. _cTensorHandle)
193
235
}
194
236
237
+ public var _tensorHandles : [ _AnyTensorHandle ] { [ self . handle. handle] }
238
+
195
239
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? ) {
196
240
self . init ( handle: TensorHandle ( _owning: tensorHandles!. pointee) )
197
241
}
242
+
243
+ public init < C: RandomAccessCollection > (
244
+ _handles: C ) where C. Element == _AnyTensorHandle {
245
+ precondition ( _handles. count == 1 )
246
+ self . init ( handle: TensorHandle ( handle: _handles [ _handles. startIndex] ) )
247
+ }
198
248
}
199
249
200
250
extension Array : TensorArrayProtocol where Element: TensorGroup {
@@ -216,10 +266,31 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
216
266
count: Int ( count) ) . joined ( ) )
217
267
}
218
268
269
+ public var _tensorHandles : ( [ _AnyTensorHandle ] ) {
270
+ var result : [ _AnyTensorHandle ] = [ ]
271
+ result. reserveCapacity ( Int ( self . _tensorHandleCount) )
272
+ for elem in self {
273
+ result += elem. _tensorHandles
274
+ }
275
+ return result
276
+ }
277
+
219
278
public init ( _owning tensorHandles: UnsafePointer < CTensorHandle > ? , count: Int ) {
220
279
let size = count / Int( Element . _tensorHandleCount)
221
280
self = Array ( ( 0 ..< size) . map { Element . init (
222
281
_owning: tensorHandles? . advanced ( by: $0 * Int( Element . _tensorHandleCount) ) )
223
282
} )
224
283
}
284
+
285
+ public init < C: RandomAccessCollection > (
286
+ _handles: C ) where C. Element == _AnyTensorHandle {
287
+ let size = _handles. count / Int( Element . _tensorHandleCount)
288
+ self = ( 0 ..< size) . map {
289
+ let start = _handles. index (
290
+ _handles. startIndex, offsetBy: $0 * Int( Element . _tensorHandleCount) )
291
+ let end = _handles. index (
292
+ start, offsetBy: Int ( Element . _tensorHandleCount) )
293
+ return Element . init ( _handles: _handles [ start..< end] )
294
+ }
295
+ }
225
296
}
0 commit comments