@@ -443,19 +443,54 @@ final class BasicOperatorTests: XCTestCase {
443
443
XCTAssertEqual ( result. shape, [ 1 , 3 , 1 , 2 , 1 ] )
444
444
}
445
445
446
- func testUnbroadcast1 ( ) {
446
+ func testUnbroadcastRank4ToRank2 ( ) {
447
447
let x = Tensor < Float > ( repeating: 1 , shape: [ 2 , 3 , 4 , 5 ] )
448
448
let y = Tensor < Float > ( repeating: 1 , shape: [ 4 , 5 ] )
449
449
let z = x. unbroadcasted ( like: y)
450
450
XCTAssertEqual ( z. array, ShapedArray < Float > ( repeating: 6 , shape: [ 4 , 5 ] ) )
451
451
}
452
452
453
- func testUnbroadcast2 ( ) {
453
+ func testUnbroadcastRank4ToRank3 ( ) {
454
454
let x = Tensor < Float > ( repeating: 1 , shape: [ 2 , 3 , 4 , 5 ] )
455
455
let y = Tensor < Float > ( repeating: 1 , shape: [ 3 , 1 , 5 ] )
456
456
let z = x. unbroadcasted ( like: y)
457
457
XCTAssertEqual ( z. array, ShapedArray < Float > ( repeating: 8 , shape: [ 3 , 1 , 5 ] ) )
458
458
}
459
+
460
+ func testUnbroadcast3x3To1x3( ) {
461
+ func foo( tensor: Tensor < Float > , shape: Tensor < Int32 > ) -> Tensor < Float > {
462
+ tensor. unbroadcasted ( toShape: shape)
463
+ }
464
+
465
+ // [3,3] -> [1,3]
466
+ let atTensor : Tensor < Float > = [
467
+ [ 1 , 2 , 3 ] ,
468
+ [ 1 , 2 , 3 ] ,
469
+ [ 1 , 2 , 3 ] ]
470
+ var pb : ( Tensor < Float > ) -> Tensor < Float > = pullback ( at: atTensor) { x in
471
+ foo ( tensor: x, shape: [ 1 , 3 ] )
472
+ }
473
+
474
+ // Same shape as parameter of pullback
475
+ var inputTensor : Tensor < Float > = [ [ 1 , 2 , 3 ] ]
476
+ var expected : Tensor < Float > = atTensor
477
+ XCTAssertEqual ( expected, pb ( inputTensor) )
478
+ // Different shape than parameter of pullback
479
+ inputTensor = [ 2 ]
480
+ expected = [
481
+ [ 2 , 2 , 2 ] ,
482
+ [ 2 , 2 , 2 ] ,
483
+ [ 2 , 2 , 2 ] ]
484
+ XCTAssertEqual ( expected, pb ( inputTensor) )
485
+
486
+ // Same shape as tensor we are differentiating at
487
+ inputTensor = [
488
+ [ 8 , 1 , 3 ] ,
489
+ [ 8 , 1 , 3 ] ,
490
+ [ 8 , 1 , 3 ] ]
491
+ expected = inputTensor
492
+ XCTAssertEqual ( expected, pb ( inputTensor) )
493
+ }
459
494
460
495
func testSliceUpdate( ) {
461
496
var t1 = Tensor < Float > ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] )
@@ -482,6 +517,82 @@ final class BasicOperatorTests: XCTestCase {
482
517
target .= Tensor ( repeating: 1 , shape: [ 1 , 3 , 1 ] )
483
518
XCTAssertEqual ( target, Tensor ( repeating: 1 , shape: [ 2 , 3 , 4 ] ) )
484
519
}
520
+
521
+ func testBroadcast3x0To3x3( ) {
522
+ func foo( tensor: Tensor < Float > , shape: Tensor < Int32 > ) -> Tensor < Float > {
523
+ tensor. broadcasted ( toShape: shape)
524
+ }
525
+
526
+ // [3,] -> [3,3]
527
+ var pb : ( Tensor < Float > ) -> Tensor < Float > = pullback ( at: [ 99 , 33 , 55 ] ) { x in
528
+ foo ( tensor: x, shape: [ 3 , 3 ] )
529
+ }
530
+
531
+ // Same shape as parameter of pullback
532
+ var inputTensor : Tensor < Float > = [
533
+ [ 1 , 2 , 3 ] ,
534
+ [ 1 , 2 , 3 ] ,
535
+ [ 1 , 2 , 3 ] ]
536
+ var expected : Tensor < Float > = [ 3 , 6 , 9 ]
537
+ XCTAssertEqual ( expected, pb ( inputTensor) )
538
+
539
+ // Different shape than parameter of pullback
540
+ inputTensor = [
541
+ [ 1 , 2 , 3 ] ,
542
+ [ 1 , 2 , 3 ] ,
543
+ [ 1 , 2 , 3 ] ,
544
+ [ 1 , 2 , 3 ] ]
545
+ expected = [ 4 , 8 , 12 ]
546
+ XCTAssertEqual ( expected, pb ( inputTensor) )
547
+
548
+ // Same shape as tensor we are differentiating at
549
+ inputTensor = [ 1 , 2 , 3 ]
550
+ expected = [ 1 , 2 , 3 ]
551
+ XCTAssertEqual ( expected, pb ( inputTensor) )
552
+
553
+ // Extremely padded shape as tensor we are differentiating at
554
+ inputTensor = [ [ [ [ [ [ 1 , 2 , 3 ] ] ] ] ] ]
555
+ expected = [ 1 , 2 , 3 ]
556
+ XCTAssertEqual ( expected, pb ( inputTensor) )
557
+ }
558
+
559
+ func testBroadcast3x1To3x3( ) {
560
+ func foo( tensor: Tensor < Float > , shape: Tensor < Int32 > ) -> Tensor < Float > {
561
+ tensor. broadcasted ( toShape: shape)
562
+ }
563
+
564
+ // [3,1] -> [3x3]
565
+ var pb : ( Tensor < Float > ) -> Tensor < Float > = pullback ( at: [ [ 99 , 33 , 55 ] ] ) { x in
566
+ foo ( tensor: x, shape: [ 3 , 3 ] )
567
+ }
568
+
569
+ // Same shape as parameter of pullback
570
+ var inputTensor : Tensor < Float > = [
571
+ [ 1 , 2 , 3 ] ,
572
+ [ 1 , 2 , 3 ] ,
573
+ [ 1 , 2 , 3 ] ]
574
+ var expected : Tensor < Float > = [ [ 3 , 6 , 9 ] ]
575
+ XCTAssertEqual ( expected, pb ( inputTensor) )
576
+
577
+ // Different shape than parameter of pullback
578
+ inputTensor = [
579
+ [ 1 , 2 , 3 ] ,
580
+ [ 1 , 2 , 3 ] ,
581
+ [ 1 , 2 , 3 ] ,
582
+ [ 1 , 2 , 3 ] ]
583
+ expected = [ [ 4 , 8 , 12 ] ]
584
+ XCTAssertEqual ( expected, pb ( inputTensor) )
585
+
586
+ // Same shape as tensor we are differentiating at
587
+ inputTensor = [ [ 1 , 2 , 3 ] ]
588
+ expected = [ [ 1 , 2 , 3 ] ]
589
+ XCTAssertEqual ( expected, pb ( inputTensor) )
590
+
591
+ // Extremely padded shape of tensor we are differentiating at
592
+ inputTensor = [ [ [ [ [ [ 1 , 2 , 3 ] ] ] ] ] ]
593
+ expected = [ [ 1 , 2 , 3 ] ]
594
+ XCTAssertEqual ( expected, pb ( inputTensor) )
595
+ }
485
596
486
597
static var allTests = [
487
598
( " testGathering " , testGathering) ,
@@ -507,9 +618,12 @@ final class BasicOperatorTests: XCTestCase {
507
618
( " testFlatten0D " , testFlatten0D) ,
508
619
( " testReshapeToScalar " , testReshapeToScalar) ,
509
620
( " testReshapeTensor " , testReshapeTensor) ,
510
- ( " testUnbroadcast1 " , testUnbroadcast1) ,
511
- ( " testUnbroadcast2 " , testUnbroadcast2) ,
621
+ ( " testUnbroadcastRank4ToRank2 " , testUnbroadcastRank4ToRank2) ,
622
+ ( " testUnbroadcastRank4ToRank3 " , testUnbroadcastRank4ToRank3) ,
623
+ ( " testUnbroadcast3x3To1x3 " , testUnbroadcast3x3To1x3) ,
512
624
( " testSliceUpdate " , testSliceUpdate) ,
625
+ ( " testBroadcast3x0To3x3 " , testBroadcast3x0To3x3) ,
626
+ ( " testBroadcast3x1To3x3 " , testBroadcast3x1To3x3) ,
513
627
( " testBroadcastTensor " , testBroadcastTensor)
514
628
]
515
629
}
0 commit comments