@@ -360,6 +360,26 @@ final class LayerTests: XCTestCase {
360
360
XCTAssertEqual ( output, expected)
361
361
}
362
362
363
+ func testAvgPool1DGradient( ) {
364
+ let layer = AvgPool1D < Float > ( poolSize: 2 , stride: 1 , padding: . valid)
365
+ let x = Tensor ( shape: [ 1 , 4 , 4 ] , scalars: ( 0 ..< 16 ) . map ( Float . init) )
366
+ let computedGradient = gradient ( at: x, layer) { $1 ( $0) . sum ( ) }
367
+ // The expected value of the gradient was computed using the following Python code:
368
+ // ```
369
+ // avgpool1D = tf.keras.layers.AvgPool1D(strides=1)
370
+ // with tf.GradientTape() as t:
371
+ // t.watch(x)
372
+ // y = tf.math.reduce_sum(avgpool1D(x))
373
+ // print(t.gradient(y, x))
374
+ // ```
375
+ let expectedGradient = Tensor < Float > ( [ [
376
+ [ 0.5 , 0.5 , 0.5 , 0.5 ] ,
377
+ [ 1.0 , 1.0 , 1.0 , 1.0 ] ,
378
+ [ 1.0 , 1.0 , 1.0 , 1.0 ] ,
379
+ [ 0.5 , 0.5 , 0.5 , 0.5 ] ] ] )
380
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient)
381
+ }
382
+
363
383
func testAvgPool2D( ) {
364
384
let layer = AvgPool2D < Float > ( poolSize: ( 2 , 5 ) , strides: ( 1 , 1 ) , padding: . valid)
365
385
let input = Tensor ( shape: [ 1 , 2 , 5 , 1 ] , scalars: ( 0 ..< 10 ) . map ( Float . init) )
@@ -368,6 +388,26 @@ final class LayerTests: XCTestCase {
368
388
XCTAssertEqual ( output, expected)
369
389
}
370
390
391
+ func testAvgPool2DGradient( ) {
392
+ let layer = AvgPool2D < Float > ( poolSize: ( 2 , 2 ) , strides: ( 1 , 1 ) , padding: . valid)
393
+ let x = Tensor ( shape: [ 1 , 4 , 4 , 2 ] , scalars: ( 0 ..< 32 ) . map ( Float . init) )
394
+ let computedGradient = gradient ( at: x, layer) { $1 ( $0) . sum ( ) }
395
+ // The expected value of the gradient was computed using the following Python code:
396
+ // ```
397
+ // avgpool2D = tf.keras.layers.AvgPool2D(strides=(1, 1))
398
+ // with tf.GradientTape() as t:
399
+ // t.watch(x)
400
+ // y = tf.math.reduce_sum(avgpool2D(x))
401
+ // print(t.gradient(y, x))
402
+ // ```
403
+ let expectedGradient = Tensor < Float > ( [ [
404
+ [ [ 0.25 , 0.25 ] , [ 0.50 , 0.50 ] , [ 0.50 , 0.50 ] , [ 0.25 , 0.25 ] ] ,
405
+ [ [ 0.50 , 0.50 ] , [ 1.00 , 1.00 ] , [ 1.00 , 1.00 ] , [ 0.50 , 0.50 ] ] ,
406
+ [ [ 0.50 , 0.50 ] , [ 1.00 , 1.00 ] , [ 1.00 , 1.00 ] , [ 0.50 , 0.50 ] ] ,
407
+ [ [ 0.25 , 0.25 ] , [ 0.50 , 0.50 ] , [ 0.50 , 0.50 ] , [ 0.25 , 0.25 ] ] ] ] )
408
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient)
409
+ }
410
+
371
411
func testAvgPool3D( ) {
372
412
let layer = AvgPool3D < Float > ( poolSize: ( 2 , 4 , 5 ) , strides: ( 1 , 1 , 1 ) , padding: . valid)
373
413
let input = Tensor ( shape: [ 1 , 2 , 4 , 5 , 1 ] , scalars: ( 0 ..< 40 ) . map ( Float . init) )
@@ -376,6 +416,22 @@ final class LayerTests: XCTestCase {
376
416
XCTAssertEqual ( output, expected)
377
417
}
378
418
419
+ func testAvgPool3DGradient( ) {
420
+ let layer = AvgPool3D < Float > ( poolSize: ( 2 , 2 , 2 ) , strides: ( 1 , 1 , 1 ) , padding: . valid)
421
+ let x = Tensor ( shape: [ 1 , 2 , 2 , 2 , 1 ] , scalars: ( 0 ..< 8 ) . map ( Float . init) )
422
+ let computedGradient = gradient ( at: x, layer) { $1 ( $0) . sum ( ) }
423
+ // The expected value of the gradient was computed using the following Python code:
424
+ // ```
425
+ // avgpool3D = tf.keras.layers.AvgPool3D(strides=(1, 1, 1))
426
+ // with tf.GradientTape() as t:
427
+ // t.watch(x)
428
+ // y = tf.math.reduce_sum(avgpool3D(x))
429
+ // print(t.gradient(y, x))
430
+ // ```
431
+ let expectedGradient = Tensor < Float > ( repeating: 0.125 , shape: [ 1 , 2 , 2 , 2 , 1 ] )
432
+ XCTAssertEqual ( computedGradient. 0 , expectedGradient)
433
+ }
434
+
379
435
func testGlobalAvgPool1D( ) {
380
436
let layer = GlobalAvgPool1D < Float > ( )
381
437
let input = Tensor ( shape: [ 2 , 5 , 1 ] , scalars: ( 0 ..< 10 ) . map ( Float . init) )
@@ -1018,8 +1074,11 @@ final class LayerTests: XCTestCase {
1018
1074
( " testMaxPool3D " , testMaxPool3D) ,
1019
1075
( " testMaxPool3DGradient " , testMaxPool3DGradient) ,
1020
1076
( " testAvgPool1D " , testAvgPool1D) ,
1077
+ ( " testAvgPool1DGradient " , testAvgPool1DGradient) ,
1021
1078
( " testAvgPool2D " , testAvgPool2D) ,
1079
+ ( " testAvgPool2DGradient " , testAvgPool2DGradient) ,
1022
1080
( " testAvgPool3D " , testAvgPool3D) ,
1081
+ ( " testAvgPool3DGradient " , testAvgPool3DGradient) ,
1023
1082
( " testGlobalAvgPool1D " , testGlobalAvgPool1D) ,
1024
1083
( " testGlobalAvgPool1DGradient " , testGlobalAvgPool1DGradient) ,
1025
1084
( " testGlobalAvgPool2D " , testGlobalAvgPool2D) ,
0 commit comments