@@ -214,6 +214,26 @@ final class LayerTests: XCTestCase {
214
214
XCTAssertEqual ( output, expected)
215
215
}
216
216
217
+ func testMaxPool1DGradient( ) {
218
+ let layer = MaxPool1D < Float > ( poolSize: 2 , stride: 1 , padding: . valid)
219
+ let x = Tensor < Float > ( shape: [ 1 , 4 , 4 ] , scalars: ( 0 ..< 16 ) . map ( Float . init) )
220
+ let computedGradient = gradient ( at: x, layer) { $1 ( $0) . sum ( ) }
221
+ // The expected value of the gradient was computed using the following Python code:
222
+ // ```
223
+ // maxpool1D = tf.keras.layers.MaxPool1D()
224
+ // with tf.GradientTape() as t:
225
+ // t.watch(x)
226
+ // y = tf.math.reduce_sum(maxpool1D(x))
227
+ // print(t.gradient(y, x))
228
+ // ```
229
+ let expectedGradient = Tensor < Float > ( [ [
230
+ [ 0 , 0 , 0 , 0 ] ,
231
+ [ 1 , 1 , 1 , 1 ] ,
232
+ [ 1 , 1 , 1 , 1 ] ,
233
+ [ 1 , 1 , 1 , 1 ] ] ] )
234
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient)
235
+ }
236
+
217
237
func testMaxPool2D( ) {
218
238
let layer = MaxPool2D < Float > ( poolSize: ( 2 , 2 ) , strides: ( 1 , 1 ) , padding: . valid)
219
239
let input = Tensor ( shape: [ 1 , 2 , 2 , 1 ] , scalars: ( 0 ..< 4 ) . map ( Float . init) )
@@ -222,6 +242,26 @@ final class LayerTests: XCTestCase {
222
242
XCTAssertEqual ( output, expected)
223
243
}
224
244
245
+ func testMaxPool2DGradient( ) {
246
+ let layer = MaxPool2D < Float > ( poolSize: ( 2 , 2 ) , strides: ( 2 , 2 ) , padding: . valid)
247
+ let x = Tensor ( shape: [ 1 , 4 , 4 , 1 ] , scalars: ( 0 ..< 16 ) . map ( Float . init) )
248
+ let computedGradient = gradient ( at: x, layer) { $1 ( $0) . sum ( ) }
249
+ // The expected value of the gradient was computed using the following Python code:
250
+ // ```
251
+ // maxpool2D = tf.keras.layers.MaxPool2D(strides=(2, 2))
252
+ // with tf.GradientTape() as t:
253
+ // t.watch(x)
254
+ // y = tf.math.reduce_sum(maxpool2D(x))
255
+ // print(t.gradient(y, x))
256
+ // ```
257
+ let expectedGradient = Tensor < Float > ( [ [
258
+ [ [ 0 ] , [ 0 ] , [ 0 ] , [ 0 ] ] ,
259
+ [ [ 0 ] , [ 1 ] , [ 0 ] , [ 1 ] ] ,
260
+ [ [ 0 ] , [ 0 ] , [ 0 ] , [ 0 ] ] ,
261
+ [ [ 0 ] , [ 1 ] , [ 0 ] , [ 1 ] ] ] ] )
262
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient)
263
+ }
264
+
225
265
func testMaxPool3D( ) {
226
266
let layer = MaxPool3D < Float > ( poolSize: ( 2 , 2 , 2 ) , strides: ( 1 , 1 , 1 ) , padding: . valid)
227
267
let input = Tensor ( shape: [ 1 , 2 , 2 , 2 , 1 ] , scalars: ( 0 ..< 8 ) . map ( Float . init) )
@@ -230,6 +270,26 @@ final class LayerTests: XCTestCase {
230
270
XCTAssertEqual ( output, expected)
231
271
}
232
272
273
+ func testMaxPool3DGradient( ) {
274
+ let layer = MaxPool3D < Float > ( poolSize: ( 2 , 2 , 2 ) , strides: ( 1 , 1 , 1 ) , padding: . valid)
275
+ let x = Tensor ( shape: [ 1 , 2 , 2 , 2 , 1 ] , scalars: ( 0 ..< 8 ) . map ( Float . init) )
276
+ let computedGradient = gradient ( at: x, layer) { $1 ( $0) . sum ( ) }
277
+ // The expected value of the gradient was computed using the following Python code:
278
+ // ```
279
+ // maxpool3D = tf.keras.layers.MaxPool3D(strides=(1, 1, 1))
280
+ // with tf.GradientTape() as t:
281
+ // t.watch(x)
282
+ // y = tf.math.reduce_sum(maxpool3D(x))
283
+ // print(t.gradient(y, x))
284
+ // ```
285
+ let expectedGradient = Tensor < Float > ( [ [
286
+ [ [ [ 0 ] , [ 0 ] ] ,
287
+ [ [ 0 ] , [ 0 ] ] ] ,
288
+ [ [ [ 0 ] , [ 0 ] ] ,
289
+ [ [ 0 ] , [ 1 ] ] ] ] ] )
290
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient)
291
+ }
292
+
233
293
func testAvgPool1D( ) {
234
294
let layer = AvgPool1D < Float > ( poolSize: 3 , stride: 1 , padding: . valid)
235
295
let input = Tensor < Float > ( [ [ 0 , 1 , 2 , 3 , 4 ] , [ 10 , 11 , 12 , 13 , 14 ] ] ) . expandingShape ( at: 2 )
@@ -645,8 +705,11 @@ final class LayerTests: XCTestCase {
645
705
( " testZeroPadding2D " , testZeroPadding2D) ,
646
706
( " testZeroPadding3D " , testZeroPadding3D) ,
647
707
( " testMaxPool1D " , testMaxPool1D) ,
708
+ ( " testMaxPool1DGradient " , testMaxPool1DGradient) ,
648
709
( " testMaxPool2D " , testMaxPool2D) ,
710
+ ( " testMaxPool2DGradient " , testMaxPool2DGradient) ,
649
711
( " testMaxPool3D " , testMaxPool3D) ,
712
+ ( " testMaxPool3DGradient " , testMaxPool3DGradient) ,
650
713
( " testAvgPool1D " , testAvgPool1D) ,
651
714
( " testAvgPool2D " , testAvgPool2D) ,
652
715
( " testAvgPool3D " , testAvgPool3D) ,
0 commit comments