@@ -110,6 +110,52 @@ final class LayerTests: XCTestCase {
110
110
XCTAssertEqual ( output, expected)
111
111
}
112
112
113
+ func testConv2DGradient( ) {
114
+ let filter = Tensor ( shape: [ 3 , 3 , 2 , 4 ] , scalars: ( 0 ..< 72 ) . map ( Float . init) )
115
+ let bias = Tensor < Float > ( zeros: [ 4 ] )
116
+ let layer = Conv2D < Float > ( filter: filter,
117
+ bias: bias,
118
+ activation: identity,
119
+ strides: ( 2 , 2 ) ,
120
+ padding: . valid)
121
+ let input = Tensor ( shape: [ 2 , 4 , 4 , 2 ] , scalars: ( 0 ..< 64 ) . map ( Float . init) )
122
+ let grads = gradient ( at: input, layer) { $1 ( $0) . sum ( ) }
123
+ // The expected gradients were computed using the following Python code:
124
+ // ```
125
+ // x = tf.reshape(tf.range(64, dtype=tf.float32), [2, 4, 4, 2])
126
+ // filter = tf.reshape(tf.range(72, dtype=tf.float32), [3, 3, 2, 4])
127
+ // bias = tf.zeros([4])
128
+ // with tf.GradientTape() as t:
129
+ // t.watch([x, filter, bias])
130
+ // y = tf.math.reduce_sum(tf.nn.conv2d(input=x,
131
+ // filters=filter,
132
+ // strides=[1, 2, 2, 1],
133
+ // data_format="NHWC",
134
+ // padding="VALID") + bias)
135
+ // grads = t.gradient(y, [x, filter, bias])
136
+ // ```
137
+ XCTAssertEqual ( grads. 0 ,
138
+ [ [ [ [ 6 , 22 ] , [ 38 , 54 ] , [ 70 , 86 ] , [ 0 , 0 ] ] ,
139
+ [ [ 102 , 118 ] , [ 134 , 150 ] , [ 166 , 182 ] , [ 0 , 0 ] ] ,
140
+ [ [ 198 , 214 ] , [ 230 , 246 ] , [ 262 , 278 ] , [ 0 , 0 ] ] ,
141
+ [ [ 0 , 0 ] , [ 0 , 0 ] , [ 0 , 0 ] , [ 0 , 0 ] ] ] ,
142
+ [ [ [ 6 , 22 ] , [ 38 , 54 ] , [ 70 , 86 ] , [ 0 , 0 ] ] ,
143
+ [ [ 102 , 118 ] , [ 134 , 150 ] , [ 166 , 182 ] , [ 0 , 0 ] ] ,
144
+ [ [ 198 , 214 ] , [ 230 , 246 ] , [ 262 , 278 ] , [ 0 , 0 ] ] ,
145
+ [ [ 0 , 0 ] , [ 0 , 0 ] , [ 0 , 0 ] , [ 0 , 0 ] ] ] ] )
146
+ XCTAssertEqual ( grads. 1 . filter,
147
+ [ [ [ [ 32 , 32 , 32 , 32 ] , [ 34 , 34 , 34 , 34 ] ] ,
148
+ [ [ 36 , 36 , 36 , 36 ] , [ 38 , 38 , 38 , 38 ] ] ,
149
+ [ [ 40 , 40 , 40 , 40 ] , [ 42 , 42 , 42 , 42 ] ] ] ,
150
+ [ [ [ 48 , 48 , 48 , 48 ] , [ 50 , 50 , 50 , 50 ] ] ,
151
+ [ [ 52 , 52 , 52 , 52 ] , [ 54 , 54 , 54 , 54 ] ] ,
152
+ [ [ 56 , 56 , 56 , 56 ] , [ 58 , 58 , 58 , 58 ] ] ] ,
153
+ [ [ [ 64 , 64 , 64 , 64 ] , [ 66 , 66 , 66 , 66 ] ] ,
154
+ [ [ 68 , 68 , 68 , 68 ] , [ 70 , 70 , 70 , 70 ] ] ,
155
+ [ [ 72 , 72 , 72 , 72 ] , [ 74 , 74 , 74 , 74 ] ] ] ] )
156
+ XCTAssertEqual ( grads. 1 . bias, [ 2 , 2 , 2 , 2 ] )
157
+ }
158
+
113
159
func testConv2DDilation( ) {
114
160
// Input shapes. (Data format = NHWC)
115
161
let batchSize = 2
@@ -697,6 +743,7 @@ final class LayerTests: XCTestCase {
697
743
( " testConv1D " , testConv1D) ,
698
744
( " testConv1DDilation " , testConv1DDilation) ,
699
745
( " testConv2D " , testConv2D) ,
746
+ ( " testConv2DGradient " , testConv2DGradient) ,
700
747
( " testConv2DDilation " , testConv2DDilation) ,
701
748
( " testConv3D " , testConv3D) ,
702
749
( " testDepthConv2D " , testDepthConv2D) ,
0 commit comments