@@ -40,6 +40,39 @@ final class BasicOperatorTests: XCTestCase {
40
40
XCTAssertEqual ( paddedTensor, target)
41
41
}
42
42
43
+ func testPaddedConstant( ) {
44
+ let x = Tensor < Float > ( ones: [ 2 , 2 ] )
45
+ let target = Tensor < Float > ( [ [ 3 , 3 , 3 ] , [ 1 , 1 , 3 ] , [ 1 , 1 , 3 ] ] )
46
+ let paddedTensor = x. padded ( forSizes: [ ( 1 , 0 ) , ( 0 , 1 ) ] , mode: . constant( 3.0 ) )
47
+ XCTAssertEqual ( paddedTensor, target)
48
+ }
49
+
50
+ func testPaddedReflect( ) {
51
+ let x = Tensor < Float > ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] )
52
+ let target = Tensor < Float > ( [
53
+ [ 7 , 8 , 9 , 8 , 7 ] ,
54
+ [ 4 , 5 , 6 , 5 , 4 ] ,
55
+ [ 1 , 2 , 3 , 2 , 1 ] ,
56
+ [ 4 , 5 , 6 , 5 , 4 ] ,
57
+ [ 7 , 8 , 9 , 8 , 7 ]
58
+ ] )
59
+ let paddedTensor = x. padded ( forSizes: [ ( 2 , 0 ) , ( 0 , 2 ) ] , mode: . reflect)
60
+ XCTAssertEqual ( paddedTensor, target)
61
+ }
62
+
63
+ func testPaddedSymmetric( ) {
64
+ let x = Tensor < Float > ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] )
65
+ let target = Tensor < Float > ( [
66
+ [ 4 , 5 , 6 , 6 , 5 ] ,
67
+ [ 1 , 2 , 3 , 3 , 2 ] ,
68
+ [ 1 , 2 , 3 , 3 , 2 ] ,
69
+ [ 4 , 5 , 6 , 6 , 5 ] ,
70
+ [ 7 , 8 , 9 , 9 , 8 ]
71
+ ] )
72
+ let paddedTensor = x. padded ( forSizes: [ ( 2 , 0 ) , ( 0 , 2 ) ] , mode: . symmetric)
73
+ XCTAssertEqual ( paddedTensor, target)
74
+ }
75
+
43
76
func testVJPPadded( ) {
44
77
let x = Tensor < Float > ( ones: [ 3 , 2 ] )
45
78
let target = Tensor < Float > ( [ [ 2 , 2 ] , [ 2 , 2 ] , [ 2 , 2 ] ] )
@@ -50,6 +83,36 @@ final class BasicOperatorTests: XCTestCase {
50
83
XCTAssertEqual ( grads, target)
51
84
}
52
85
86
+ func testVJPPaddedConstant( ) {
87
+ let x = Tensor < Float > ( ones: [ 3 , 2 ] )
88
+ let target = Tensor < Float > ( [ [ 2 , 2 ] , [ 2 , 2 ] , [ 2 , 2 ] ] )
89
+ let grads = x. gradient { a -> Tensor < Float > in
90
+ let paddedTensor = a. padded ( forSizes: [ ( 1 , 0 ) , ( 0 , 1 ) ] , mode: . constant( 3.0 ) )
91
+ return ( paddedTensor * paddedTensor) . sum ( )
92
+ }
93
+ XCTAssertEqual ( grads, target)
94
+ }
95
+
96
+ func testVJPPaddedReflect( ) {
97
+ let x = Tensor < Float > ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] )
98
+ let target = Tensor < Float > ( [ [ 4 , 8 , 6 ] , [ 32 , 40 , 24 ] , [ 56 , 64 , 36 ] ] )
99
+ let grads = x. gradient { a -> Tensor < Float > in
100
+ let paddedTensor = a. padded ( forSizes: [ ( 2 , 0 ) , ( 0 , 2 ) ] , mode: . reflect)
101
+ return ( paddedTensor * paddedTensor) . sum ( )
102
+ }
103
+ XCTAssertEqual ( grads, target)
104
+ }
105
+
106
+ func testVJPPaddedSymmetric( ) {
107
+ let x = Tensor < Float > ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] )
108
+ let target = Tensor < Float > ( [ [ 4 , 16 , 24 ] , [ 16 , 40 , 48 ] , [ 14 , 32 , 36 ] ] )
109
+ let grads = x. gradient { a -> Tensor < Float > in
110
+ let paddedTensor = a. padded ( forSizes: [ ( 2 , 0 ) , ( 0 , 2 ) ] , mode: . symmetric)
111
+ return ( paddedTensor * paddedTensor) . sum ( )
112
+ }
113
+ XCTAssertEqual ( grads, target)
114
+ }
115
+
53
116
func testElementIndexing( ) {
54
117
// NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly
55
118
// until send and receive are implemented (without writing a bunch of mini
@@ -599,7 +662,13 @@ final class BasicOperatorTests: XCTestCase {
599
662
( " testGathering " , testGathering) ,
600
663
( " testBatchGathering " , testBatchGathering) ,
601
664
( " testPadded " , testPadded) ,
665
+ ( " testPaddedConstant " , testPaddedConstant) ,
666
+ ( " testPaddedReflect " , testPaddedReflect) ,
667
+ ( " testPaddedSymmetric " , testPaddedSymmetric) ,
602
668
( " testVJPPadded " , testVJPPadded) ,
669
+ ( " testVJPPaddedConstant " , testVJPPaddedConstant) ,
670
+ ( " testVJPPaddedReflect " , testVJPPaddedReflect) ,
671
+ ( " testVJPPaddedSymmetric " , testVJPPaddedSymmetric) ,
603
672
( " testElementIndexing " , testElementIndexing) ,
604
673
( " testElementIndexingAssignment " , testElementIndexingAssignment) ,
605
674
( " testNestedElementIndexing " , testNestedElementIndexing) ,
0 commit comments