@@ -153,13 +153,15 @@ struct UltraNested<T: TensorGroup & Equatable, V: TensorGroup & Equatable>
153
153
}
154
154
}
155
155
156
- func copy< T> ( of handle: TensorHandle < T > ) -> TFETensorHandle {
157
- let status = TF_NewStatus ( )
158
- let result = TFETensorHandle ( _owning: TFE_TensorHandleCopySharingTensor (
159
- handle. _cTensorHandle, status) !)
160
- XCTAssertEqual ( TF_GetCode ( status) , TF_OK)
161
- TF_DeleteStatus ( status)
162
- return result
156
+ extension TensorHandle {
157
+ func makeCopy( ) -> TFETensorHandle {
158
+ let status = TF_NewStatus ( )
159
+ let result = TFETensorHandle (
160
+ _owning: TFE_TensorHandleCopySharingTensor ( handle. _cTensorHandle, status) !)
161
+ XCTAssertEqual ( TF_GetCode ( status) , TF_OK)
162
+ TF_DeleteStatus ( status)
163
+ return result
164
+ }
163
165
}
164
166
165
167
final class TensorGroupTests : XCTestCase {
@@ -178,8 +180,8 @@ final class TensorGroupTests: XCTestCase {
178
180
let b = Tensor < Float > ( 0.1 )
179
181
let simple = Simple ( w: w, b: b)
180
182
181
- let wHandle = copy ( of : w. handle)
182
- let bHandle = copy ( of : b. handle)
183
+ let wHandle = w. handle. makeCopy ( )
184
+ let bHandle = b. handle. makeCopy ( )
183
185
184
186
let expectedSimple = Simple ( _handles: [ wHandle, bHandle] )
185
187
@@ -197,8 +199,8 @@ final class TensorGroupTests: XCTestCase {
197
199
let int = Tensor < Int32 > ( 1 )
198
200
let mixed = Mixed ( float: float, int: int)
199
201
200
- let floatHandle = copy ( of : float. handle)
201
- let intHandle = copy ( of : int. handle)
202
+ let floatHandle = float. handle. makeCopy ( )
203
+ let intHandle = int. handle. makeCopy ( )
202
204
203
205
let expectedMixed = Mixed ( _handles: [ floatHandle, intHandle] )
204
206
@@ -220,10 +222,10 @@ final class TensorGroupTests: XCTestCase {
220
222
let mixed = Mixed ( float: float, int: int)
221
223
let nested = Nested ( simple: simple, mixed: mixed)
222
224
223
- let wHandle = copy ( of : w. handle)
224
- let bHandle = copy ( of : b. handle)
225
- let floatHandle = copy ( of : float. handle)
226
- let intHandle = copy ( of : int. handle)
225
+ let wHandle = w. handle. makeCopy ( )
226
+ let bHandle = b. handle. makeCopy ( )
227
+ let floatHandle = float. handle. makeCopy ( )
228
+ let intHandle = int. handle. makeCopy ( )
227
229
228
230
let expectedNested = Nested (
229
231
_handles: [ wHandle, bHandle, floatHandle, intHandle] )
@@ -247,10 +249,10 @@ final class TensorGroupTests: XCTestCase {
247
249
let mixed = Mixed ( float: float, int: int)
248
250
let generic = Generic ( t: simple, u: mixed)
249
251
250
- let wHandle = copy ( of : w. handle)
251
- let bHandle = copy ( of : b. handle)
252
- let floatHandle = copy ( of : float. handle)
253
- let intHandle = copy ( of : int. handle)
252
+ let wHandle = w. handle. makeCopy ( )
253
+ let bHandle = b. handle. makeCopy ( )
254
+ let floatHandle = float. handle. makeCopy ( )
255
+ let intHandle = int. handle. makeCopy ( )
254
256
255
257
let expectedGeneric = Generic < Simple , Mixed > (
256
258
_handles: [ wHandle, bHandle, floatHandle, intHandle] )
@@ -284,14 +286,14 @@ final class TensorGroupTests: XCTestCase {
284
286
let genericMS = Generic < Mixed , Simple > ( t: mixed, u: simple)
285
287
let generic = UltraNested ( a: genericSM, b: genericMS)
286
288
287
- let wHandle1 = copy ( of : w. handle)
288
- let wHandle2 = copy ( of : w. handle)
289
- let bHandle1 = copy ( of : b. handle)
290
- let bHandle2 = copy ( of : b. handle)
291
- let floatHandle1 = copy ( of : float. handle)
292
- let floatHandle2 = copy ( of : float. handle)
293
- let intHandle1 = copy ( of : int. handle)
294
- let intHandle2 = copy ( of : int. handle)
289
+ let wHandle1 = w. handle. makeCopy ( )
290
+ let wHandle2 = w. handle. makeCopy ( )
291
+ let bHandle1 = b. handle. makeCopy ( )
292
+ let bHandle2 = b. handle. makeCopy ( )
293
+ let floatHandle1 = float. handle. makeCopy ( )
294
+ let floatHandle2 = float. handle. makeCopy ( )
295
+ let intHandle1 = int. handle. makeCopy ( )
296
+ let intHandle2 = int. handle. makeCopy ( )
295
297
296
298
let expectedGeneric = UltraNested < Simple , Mixed > (
297
299
_handles: [ wHandle1, bHandle1, floatHandle1, intHandle1,
0 commit comments