@@ -48,14 +48,14 @@ final class BasicOperatorTests: XCTestCase {
48
48
let array0D = element0D. array
49
49
50
50
/// Test shapes
51
- XCTAssertEqual ( [ 4 , 5 ] , array2D . shape )
52
- XCTAssertEqual ( [ 5 ] , array1D. shape)
53
- XCTAssertEqual ( [ ] , array0D. shape)
51
+ XCTAssertEqual ( array2D . shape , [ 4 , 5 ] )
52
+ XCTAssertEqual ( array1D. shape, [ 5 ] )
53
+ XCTAssertEqual ( array0D. shape, [ ] )
54
54
55
55
/// Test scalars
56
- XCTAssertEqual ( Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) , array2D . scalars )
57
- XCTAssertEqual ( Array ( stride ( from: 35.0 , to: 40 , by: 1 ) ) , array1D . scalars )
58
- XCTAssertEqual ( [ 43 ] , array0D. scalars)
56
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) )
57
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 35.0 , to: 40 , by: 1 ) ) )
58
+ XCTAssertEqual ( array0D. scalars, [ 43 ] )
59
59
}
60
60
61
61
func testElementIndexingAssignment( ) {
@@ -76,14 +76,14 @@ final class BasicOperatorTests: XCTestCase {
76
76
let array0D = element0D. array
77
77
78
78
/// Test shapes
79
- XCTAssertEqual ( [ 4 , 5 ] , array2D . shape )
80
- XCTAssertEqual ( [ 5 ] , array1D. shape)
81
- XCTAssertEqual ( [ ] , array0D. shape)
79
+ XCTAssertEqual ( array2D . shape , [ 4 , 5 ] )
80
+ XCTAssertEqual ( array1D. shape, [ 5 ] )
81
+ XCTAssertEqual ( array0D. shape, [ ] )
82
82
83
83
/// Test scalars
84
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 40 , by: 1 ) ) , array2D . scalars )
85
- XCTAssertEqual ( Array ( stride ( from: 35.0 , to: 40 , by: 1 ) ) , array1D . scalars )
86
- XCTAssertEqual ( [ 23 ] , array0D. scalars)
84
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 40 , by: 1 ) ) )
85
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 35.0 , to: 40 , by: 1 ) ) )
86
+ XCTAssertEqual ( array0D. scalars, [ 23 ] )
87
87
}
88
88
89
89
func testNestedElementIndexing( ) {
@@ -99,12 +99,12 @@ final class BasicOperatorTests: XCTestCase {
99
99
let array0D = element0D. array
100
100
101
101
/// Test shapes
102
- XCTAssertEqual ( [ 5 ] , array1D. shape)
103
- XCTAssertEqual ( [ ] , array0D. shape)
102
+ XCTAssertEqual ( array1D. shape, [ 5 ] )
103
+ XCTAssertEqual ( array0D. shape, [ ] )
104
104
105
105
/// Test scalars
106
- XCTAssertEqual ( Array ( stride ( from: 35.0 , to: 40 , by: 1 ) ) , array1D . scalars )
107
- XCTAssertEqual ( [ 43 ] , array0D. scalars)
106
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 35.0 , to: 40 , by: 1 ) ) )
107
+ XCTAssertEqual ( array0D. scalars, [ 43 ] )
108
108
}
109
109
110
110
func testSliceIndexing( ) {
@@ -123,14 +123,14 @@ final class BasicOperatorTests: XCTestCase {
123
123
let array1D = slice1D. array
124
124
125
125
/// Test shapes
126
- XCTAssertEqual ( [ 1 , 4 , 5 ] , array3D . shape )
127
- XCTAssertEqual ( [ 2 , 5 ] , array2D . shape )
128
- XCTAssertEqual ( [ 2 ] , array1D. shape)
126
+ XCTAssertEqual ( array3D . shape , [ 1 , 4 , 5 ] )
127
+ XCTAssertEqual ( array2D . shape , [ 2 , 5 ] )
128
+ XCTAssertEqual ( array1D. shape, [ 2 ] )
129
129
130
130
/// Test scalars
131
- XCTAssertEqual ( Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) , array3D . scalars )
132
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) , array2D . scalars )
133
- XCTAssertEqual ( Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) , array1D . scalars )
131
+ XCTAssertEqual ( array3D . scalars , Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) )
132
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) )
133
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) )
134
134
}
135
135
136
136
func testSliceIndexingAssignment( ) {
@@ -151,14 +151,14 @@ final class BasicOperatorTests: XCTestCase {
151
151
let array1D = slice1D. array
152
152
153
153
/// Test shapes
154
- XCTAssertEqual ( [ 1 , 4 , 5 ] , array3D . shape )
155
- XCTAssertEqual ( [ 2 , 5 ] , array2D . shape )
156
- XCTAssertEqual ( [ 2 ] , array1D. shape)
154
+ XCTAssertEqual ( array3D . shape , [ 1 , 4 , 5 ] )
155
+ XCTAssertEqual ( array2D . shape , [ 2 , 5 ] )
156
+ XCTAssertEqual ( array1D. shape, [ 2 ] )
157
157
158
158
/// Test scalars
159
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 40 , by: 1 ) ) , array3D . scalars )
160
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) , array2D . scalars )
161
- XCTAssertEqual ( Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) , array1D . scalars )
159
+ XCTAssertEqual ( array3D . scalars , Array ( stride ( from: 20.0 , to: 40 , by: 1 ) ) )
160
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) )
161
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) )
162
162
}
163
163
164
164
func testEllipsisIndexing( ) {
@@ -179,14 +179,14 @@ final class BasicOperatorTests: XCTestCase {
179
179
let array1D = slice1D. array
180
180
181
181
/// Test shapes
182
- XCTAssertEqual ( [ 1 , 4 , 5 ] , array3D . shape )
183
- XCTAssertEqual ( [ 2 , 5 ] , array2D . shape )
184
- XCTAssertEqual ( [ 2 ] , array1D. shape)
182
+ XCTAssertEqual ( array3D . shape , [ 1 , 4 , 5 ] )
183
+ XCTAssertEqual ( array2D . shape , [ 2 , 5 ] )
184
+ XCTAssertEqual ( array1D. shape, [ 2 ] )
185
185
186
186
/// Test scalars
187
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 40 , by: 1 ) ) , array3D . scalars )
188
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) , array2D . scalars )
189
- XCTAssertEqual ( Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) , array1D . scalars )
187
+ XCTAssertEqual ( array3D . scalars , Array ( stride ( from: 20.0 , to: 40 , by: 1 ) ) )
188
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) )
189
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) )
190
190
}
191
191
192
192
func testNewAxisIndexing( ) {
@@ -207,14 +207,14 @@ final class BasicOperatorTests: XCTestCase {
207
207
let array1D = slice1D. array
208
208
209
209
/// Test shapes
210
- XCTAssertEqual ( [ 1 , 1 , 4 , 5 ] , array3D . shape )
211
- XCTAssertEqual ( [ 1 , 2 , 5 ] , array2D . shape )
212
- XCTAssertEqual ( [ 1 , 2 , 1 ] , array1D . shape )
210
+ XCTAssertEqual ( array3D . shape , [ 1 , 1 , 4 , 5 ] )
211
+ XCTAssertEqual ( array2D . shape , [ 1 , 2 , 5 ] )
212
+ XCTAssertEqual ( array1D . shape , [ 1 , 2 , 1 ] )
213
213
214
214
/// Test scalars
215
- XCTAssertEqual ( Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) , array3D . scalars )
216
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) , array2D . scalars )
217
- XCTAssertEqual ( Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) , array1D . scalars )
215
+ XCTAssertEqual ( array3D . scalars , Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) )
216
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) )
217
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) )
218
218
}
219
219
220
220
func testSqueezeAxisIndexing( ) {
@@ -237,14 +237,14 @@ final class BasicOperatorTests: XCTestCase {
237
237
let array1D = slice1D. array
238
238
239
239
/// Test shapes
240
- XCTAssertEqual ( [ 4 , 5 ] , array3D . shape )
241
- XCTAssertEqual ( [ 2 , 5 ] , array2D . shape )
242
- XCTAssertEqual ( [ 2 ] , array1D. shape)
240
+ XCTAssertEqual ( array3D . shape , [ 4 , 5 ] )
241
+ XCTAssertEqual ( array2D . shape , [ 2 , 5 ] )
242
+ XCTAssertEqual ( array1D. shape, [ 2 ] )
243
243
244
244
/// Test scalars
245
- XCTAssertEqual ( Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) , array3D . scalars )
246
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) , array2D . scalars )
247
- XCTAssertEqual ( Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) , array1D . scalars )
245
+ XCTAssertEqual ( array3D . scalars , Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) )
246
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) )
247
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) )
248
248
}
249
249
250
250
func testStridedSliceIndexing( ) {
@@ -263,16 +263,17 @@ final class BasicOperatorTests: XCTestCase {
263
263
let array1D = slice1D. array
264
264
265
265
/// Test shapes
266
- XCTAssertEqual ( [ 1 , 4 , 5 ] , array3D . shape )
267
- XCTAssertEqual ( [ 2 , 5 ] , array2D . shape )
268
- XCTAssertEqual ( [ 2 ] , array1D. shape)
266
+ XCTAssertEqual ( array3D . shape , [ 1 , 4 , 5 ] )
267
+ XCTAssertEqual ( array2D . shape , [ 2 , 5 ] )
268
+ XCTAssertEqual ( array1D. shape, [ 2 ] )
269
269
270
270
/// Test scalars
271
- XCTAssertEqual ( Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) , array3D . scalars )
271
+ XCTAssertEqual ( array3D . scalars , Array ( stride ( from: 40.0 , to: 60 , by: 1 ) ) )
272
272
XCTAssertEqual (
273
+ array2D. scalars,
273
274
Array ( stride ( from: 20.0 , to: 25 , by: 1 ) ) +
274
- Array( stride ( from: 30.0 , to: 35 , by: 1 ) ) , array2D . scalars )
275
- XCTAssertEqual ( Array ( stride ( from: 1.0 , to: 5 , by: 2 ) ) , array1D . scalars )
275
+ Array( stride ( from: 30.0 , to: 35 , by: 1 ) ) )
276
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 1.0 , to: 5 , by: 2 ) ) )
276
277
}
277
278
278
279
func testStridedSliceIndexingAssignment( ) {
@@ -291,28 +292,28 @@ final class BasicOperatorTests: XCTestCase {
291
292
let array3D = slice3D. array
292
293
let array2D = slice2D. array
293
294
let array1D = slice1D. array
294
-
295
+
295
296
/// Test shapes
296
- XCTAssertEqual ( [ 1 , 4 , 5 ] , array3D . shape )
297
- XCTAssertEqual ( [ 2 , 5 ] , array2D . shape )
298
- XCTAssertEqual ( [ 2 ] , array1D. shape)
297
+ XCTAssertEqual ( array3D . shape , [ 1 , 4 , 5 ] )
298
+ XCTAssertEqual ( array2D . shape , [ 2 , 5 ] )
299
+ XCTAssertEqual ( array1D. shape, [ 2 ] )
299
300
300
301
/// Test scalars
301
- XCTAssertEqual (
302
- Array ( stride ( from: 20.0 , to: 30 , by: 2 ) ) +
303
- Array ( stride ( from: 45.0 , to: 50 , by: 1 ) ) +
304
- Array ( stride ( from: 30.0 , to: 40 , by: 2 ) ) +
305
- Array ( stride ( from: 55.0 , to: 60 , by: 1 ) ) , array3D . scalars )
306
- XCTAssertEqual ( Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) , array2D . scalars )
307
- XCTAssertEqual ( Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) , array1D . scalars )
302
+ XCTAssertEqual ( array3D . scalars ,
303
+ [ Float ] ( stride ( from: 20.0 , to: 30 , by: 2 ) ) +
304
+ [ Float ] ( stride ( from: 45.0 , to: 50 , by: 1 ) ) +
305
+ [ Float ] ( stride ( from: 30.0 , to: 40 , by: 2 ) ) +
306
+ [ Float ] ( stride ( from: 55.0 , to: 60 , by: 1 ) ) )
307
+ XCTAssertEqual ( array2D . scalars , Array ( stride ( from: 20.0 , to: 30 , by: 1 ) ) )
308
+ XCTAssertEqual ( array1D . scalars , Array ( stride ( from: 3.0 , to: 5 , by: 1 ) ) )
308
309
}
309
310
310
311
func testWholeTensorSlicing( ) {
311
312
let t : Tensor < Int32 > = [ [ [ 1 , 1 , 1 ] , [ 2 , 2 , 2 ] ] ,
312
313
[ [ 3 , 3 , 3 ] , [ 4 , 4 , 4 ] ] ,
313
314
[ [ 5 , 5 , 5 ] , [ 6 , 6 , 6 ] ] ]
314
315
let slice2 = t. slice ( lowerBounds: [ 1 , 0 , 0 ] , upperBounds: [ 2 , 1 , 3 ] )
315
- XCTAssertEqual ( ShapedArray ( shape: [ 1 , 1 , 3 ] , scalars: [ 3 , 3 , 3 ] ) , slice2 . array )
316
+ XCTAssertEqual ( slice2 . array , ShapedArray ( shape: [ 1 , 1 , 3 ] , scalars: [ 3 , 3 , 3 ] ) )
316
317
}
317
318
318
319
func testAdvancedIndexing( ) {
@@ -326,10 +327,10 @@ final class BasicOperatorTests: XCTestCase {
326
327
let array2D = element2D. array
327
328
328
329
// Test shape
329
- XCTAssertEqual ( [ 2 , 2 ] , array2D . shape )
330
+ XCTAssertEqual ( array2D . shape , [ 2 , 2 ] )
330
331
331
332
// Test scalars
332
- XCTAssertEqual ( Array ( [ 23.0 , 24.0 , 43.0 , 44.0 ] ) , array2D . scalars )
333
+ XCTAssertEqual ( array2D . scalars , Array ( [ 23.0 , 24.0 , 43.0 , 44.0 ] ) )
333
334
}
334
335
335
336
func testConcatenation( ) {
@@ -340,11 +341,11 @@ final class BasicOperatorTests: XCTestCase {
340
341
let concatenated = t1 ++ t2
341
342
let concatenated0 = t1. concatenated ( with: t2)
342
343
let concatenated1 = t1. concatenated ( with: t2, alongAxis: 1 )
343
- XCTAssertEqual ( ShapedArray ( shape: [ 4 , 3 ] , scalars: Array ( 0 ..< 12 ) ) , concatenated . array )
344
- XCTAssertEqual ( ShapedArray ( shape: [ 4 , 3 ] , scalars: Array ( 0 ..< 12 ) ) , concatenated0 . array )
344
+ XCTAssertEqual ( concatenated . array , ShapedArray ( shape: [ 4 , 3 ] , scalars: Array ( 0 ..< 12 ) ) )
345
+ XCTAssertEqual ( concatenated0 . array , ShapedArray ( shape: [ 4 , 3 ] , scalars: Array ( 0 ..< 12 ) ) )
345
346
XCTAssertEqual (
346
- ShapedArray ( shape : [ 2 , 6 ] , scalars : [ 0 , 1 , 2 , 6 , 7 , 8 , 3 , 4 , 5 , 9 , 10 , 11 ] ) ,
347
- concatenated1 . array )
347
+ concatenated1 . array ,
348
+ ShapedArray ( shape : [ 2 , 6 ] , scalars : [ 0 , 1 , 2 , 6 , 7 , 8 , 3 , 4 , 5 , 9 , 10 , 11 ] ) )
348
349
}
349
350
350
351
func testVJPConcatenation( ) {
@@ -358,8 +359,8 @@ final class BasicOperatorTests: XCTestCase {
358
359
return ( ( a1 * a) ++ ( b1 * b) ) . sum ( )
359
360
}
360
361
361
- XCTAssertEqual ( a1 , grads. 0 )
362
- XCTAssertEqual ( b1 , grads. 1 )
362
+ XCTAssertEqual ( grads. 0 , a1 )
363
+ XCTAssertEqual ( grads. 1 , b1 )
363
364
}
364
365
365
366
func testVJPConcatenationNegativeAxis( ) {
@@ -373,96 +374,96 @@ final class BasicOperatorTests: XCTestCase {
373
374
return ( a1 * a) . concatenated ( with: b1 * b, alongAxis: - 1 ) . sum ( )
374
375
}
375
376
376
- XCTAssertEqual ( a1 , grads. 0 )
377
- XCTAssertEqual ( b1 , grads. 1 )
377
+ XCTAssertEqual ( grads. 0 , a1 )
378
+ XCTAssertEqual ( grads. 1 , b1 )
378
379
}
379
380
380
381
func testTranspose( ) {
381
382
// 3 x 2 -> 2 x 3
382
383
let xT = Tensor < Float > ( [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ) . transposed ( )
383
384
let xTArray = xT. array
384
- XCTAssertEqual ( 2 , xTArray. rank)
385
- XCTAssertEqual ( [ 2 , 3 ] , xTArray . shape )
386
- XCTAssertEqual ( [ 1 , 3 , 5 , 2 , 4 , 6 ] , xTArray . scalars )
385
+ XCTAssertEqual ( xTArray. rank, 2 )
386
+ XCTAssertEqual ( xTArray . shape , [ 2 , 3 ] )
387
+ XCTAssertEqual ( xTArray . scalars , [ 1 , 3 , 5 , 2 , 4 , 6 ] )
387
388
}
388
389
389
390
func testReshape( ) {
390
391
// 2 x 3 -> 1 x 3 x 1 x 2 x 1
391
392
let matrix = Tensor < Int32 > ( [ [ 0 , 1 , 2 ] , [ 3 , 4 , 5 ] ] )
392
393
let reshaped = matrix. reshaped ( to: [ 1 , 3 , 1 , 2 , 1 ] )
393
394
394
- XCTAssertEqual ( [ 1 , 3 , 1 , 2 , 1 ] , reshaped . shape )
395
- XCTAssertEqual ( Array ( 0 ..< 6 ) , reshaped . scalars )
395
+ XCTAssertEqual ( reshaped . shape , [ 1 , 3 , 1 , 2 , 1 ] )
396
+ XCTAssertEqual ( reshaped . scalars , Array ( 0 ..< 6 ) )
396
397
}
397
398
398
399
func testFlatten( ) {
399
400
// 2 x 3 -> 6
400
401
let matrix = Tensor < Int32 > ( [ [ 0 , 1 , 2 ] , [ 3 , 4 , 5 ] ] )
401
402
let flattened = matrix. flattened ( )
402
403
403
- XCTAssertEqual ( [ 6 ] , flattened. shape)
404
- XCTAssertEqual ( Array ( 0 ..< 6 ) , flattened . scalars )
404
+ XCTAssertEqual ( flattened. shape, [ 6 ] )
405
+ XCTAssertEqual ( flattened . scalars , Array ( 0 ..< 6 ) )
405
406
}
406
407
407
408
func testFlatten0D( ) {
408
409
let scalar = Tensor < Float > ( 5 )
409
410
let flattened = scalar. flattened ( )
410
- XCTAssertEqual ( [ 1 ] , flattened. shape)
411
- XCTAssertEqual ( [ 5 ] , flattened. scalars)
411
+ XCTAssertEqual ( flattened. shape, [ 1 ] )
412
+ XCTAssertEqual ( flattened. scalars, [ 5 ] )
412
413
}
413
414
414
415
func testReshapeToScalar( ) {
415
416
// 1 x 1 -> scalar
416
417
let z = Tensor < Float > ( [ [ 10 ] ] ) . reshaped ( to: [ ] )
417
- XCTAssertEqual ( [ ] , z. shape)
418
+ XCTAssertEqual ( z. shape, [ ] )
418
419
}
419
420
420
421
func testReshapeTensor( ) {
421
422
// 2 x 3 -> 1 x 3 x 1 x 2 x 1
422
423
let x = Tensor < Float > ( repeating: 0.0 , shape: [ 2 , 3 ] )
423
424
let y = Tensor < Float > ( repeating: 0.0 , shape: [ 1 , 3 , 1 , 2 , 1 ] )
424
425
let result = x. reshaped ( like: y)
425
- XCTAssertEqual ( [ 1 , 3 , 1 , 2 , 1 ] , result . shape )
426
+ XCTAssertEqual ( result . shape , [ 1 , 3 , 1 , 2 , 1 ] )
426
427
}
427
428
428
429
func testUnbroadcast1( ) {
429
430
let x = Tensor < Float > ( repeating: 1 , shape: [ 2 , 3 , 4 , 5 ] )
430
431
let y = Tensor < Float > ( repeating: 1 , shape: [ 4 , 5 ] )
431
432
let z = x. unbroadcasted ( like: y)
432
- XCTAssertEqual ( ShapedArray < Float > ( repeating: 6 , shape: [ 4 , 5 ] ) , z . array )
433
+ XCTAssertEqual ( z . array , ShapedArray < Float > ( repeating: 6 , shape: [ 4 , 5 ] ) )
433
434
}
434
435
435
436
func testUnbroadcast2( ) {
436
437
let x = Tensor < Float > ( repeating: 1 , shape: [ 2 , 3 , 4 , 5 ] )
437
438
let y = Tensor < Float > ( repeating: 1 , shape: [ 3 , 1 , 5 ] )
438
439
let z = x. unbroadcasted ( like: y)
439
- XCTAssertEqual ( ShapedArray < Float > ( repeating: 8 , shape: [ 3 , 1 , 5 ] ) , z . array )
440
+ XCTAssertEqual ( z . array , ShapedArray < Float > ( repeating: 8 , shape: [ 3 , 1 , 5 ] ) )
440
441
}
441
442
442
443
func testSliceUpdate( ) {
443
444
var t1 = Tensor < Float > ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] )
444
445
t1 [ 0 ] = Tensor ( zeros: [ 3 ] )
445
- XCTAssertEqual ( ShapedArray ( shape: [ 2 , 3 ] , scalars: [ 0 , 0 , 0 , 4 , 5 , 6 ] ) , t1 . array )
446
+ XCTAssertEqual ( t1 . array , ShapedArray ( shape: [ 2 , 3 ] , scalars: [ 0 , 0 , 0 , 4 , 5 , 6 ] ) )
446
447
var t2 = t1
447
448
t2 [ 0 ] [ 2 ] = Tensor ( 3 )
448
- XCTAssertEqual ( ShapedArray ( shape: [ 2 , 3 ] , scalars: [ 0 , 0 , 3 , 4 , 5 , 6 ] ) , t2 . array )
449
+ XCTAssertEqual ( t2 . array , ShapedArray ( shape: [ 2 , 3 ] , scalars: [ 0 , 0 , 3 , 4 , 5 , 6 ] ) )
449
450
var t3 = Tensor < Bool > ( [ [ true , true , true ] , [ false , false , false ] ] )
450
451
t3 [ 0 ] [ 1 ] = Tensor ( false )
451
- XCTAssertEqual ( ShapedArray (
452
- shape: [ 2 , 3 ] , scalars: [ true , false , true , false , false , false ] ) , t3 . array )
452
+ XCTAssertEqual ( t3 . array , ShapedArray (
453
+ shape: [ 2 , 3 ] , scalars: [ true , false , true , false , false , false ] ) )
453
454
var t4 = Tensor < Bool > ( [ [ true , true , true ] , [ false , false , false ] ] )
454
455
t4 [ 0 ] = Tensor ( repeating: false , shape: [ 3 ] )
455
- XCTAssertEqual ( ShapedArray ( repeating: false , shape: [ 2 , 3 ] ) , t4 . array )
456
+ XCTAssertEqual ( t4 . array , ShapedArray ( repeating: false , shape: [ 2 , 3 ] ) )
456
457
}
457
458
458
459
func testBroadcastTensor( ) {
459
460
// 1 -> 2 x 3 x 4
460
461
let one = Tensor < Float > ( 1 )
461
462
var target = Tensor < Float > ( repeating: 0.0 , shape: [ 2 , 3 , 4 ] )
462
463
let broadcasted = one. broadcasted ( like: target)
463
- XCTAssertEqual ( Tensor ( repeating: 1 , shape: [ 2 , 3 , 4 ] ) , broadcasted )
464
+ XCTAssertEqual ( broadcasted , Tensor ( repeating: 1 , shape: [ 2 , 3 , 4 ] ) )
464
465
target .= Tensor ( repeating: 1 , shape: [ 1 , 3 , 1 ] )
465
- XCTAssertEqual ( Tensor ( repeating: 1 , shape: [ 2 , 3 , 4 ] ) , target )
466
+ XCTAssertEqual ( target , Tensor ( repeating: 1 , shape: [ 2 , 3 , 4 ] ) )
466
467
}
467
468
468
469
static var allTests = [
0 commit comments