Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 89f6104

Browse files
jon-toweaplatanios
authored andcommitted
Add AvgPool derivative tests. (#472)
1 parent 959c48a commit 89f6104

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,26 @@ final class LayerTests: XCTestCase {
360360
XCTAssertEqual(output, expected)
361361
}
362362

363+
func testAvgPool1DGradient() {
364+
let layer = AvgPool1D<Float>(poolSize: 2, stride: 1, padding: .valid)
365+
let x = Tensor(shape: [1, 4, 4], scalars: (0..<16).map(Float.init))
366+
let computedGradient = gradient(at: x, layer) { $1($0).sum() }
367+
// The expected value of the gradient was computed using the following Python code:
368+
// ```
369+
// avgpool1D = tf.keras.layers.AvgPool1D(strides=1)
370+
// with tf.GradientTape() as t:
371+
// t.watch(x)
372+
// y = tf.math.reduce_sum(avgpool1D(x))
373+
// print(t.gradient(y, x))
374+
// ```
375+
let expectedGradient = Tensor<Float>([[
376+
[0.5, 0.5, 0.5, 0.5],
377+
[1.0, 1.0, 1.0, 1.0],
378+
[1.0, 1.0, 1.0, 1.0],
379+
[0.5, 0.5, 0.5, 0.5]]])
380+
XCTAssertEqual(computedGradient.0, expectedGradient)
381+
}
382+
363383
func testAvgPool2D() {
364384
let layer = AvgPool2D<Float>(poolSize: (2, 5), strides: (1, 1), padding: .valid)
365385
let input = Tensor(shape: [1, 2, 5, 1], scalars: (0..<10).map(Float.init))
@@ -368,6 +388,26 @@ final class LayerTests: XCTestCase {
368388
XCTAssertEqual(output, expected)
369389
}
370390

391+
func testAvgPool2DGradient() {
392+
let layer = AvgPool2D<Float>(poolSize: (2, 2), strides: (1, 1), padding: .valid)
393+
let x = Tensor(shape: [1, 4, 4, 2], scalars: (0..<32).map(Float.init))
394+
let computedGradient = gradient(at: x, layer) { $1($0).sum() }
395+
// The expected value of the gradient was computed using the following Python code:
396+
// ```
397+
// avgpool2D = tf.keras.layers.AvgPool2D(strides=(1, 1))
398+
// with tf.GradientTape() as t:
399+
// t.watch(x)
400+
// y = tf.math.reduce_sum(avgpool2D(x))
401+
// print(t.gradient(y, x))
402+
// ```
403+
let expectedGradient = Tensor<Float>([[
404+
[[0.25, 0.25], [0.50, 0.50], [0.50, 0.50], [0.25, 0.25]],
405+
[[0.50, 0.50], [1.00, 1.00], [1.00, 1.00], [0.50, 0.50]],
406+
[[0.50, 0.50], [1.00, 1.00], [1.00, 1.00], [0.50, 0.50]],
407+
[[0.25, 0.25], [0.50, 0.50], [0.50, 0.50], [0.25, 0.25]]]])
408+
XCTAssertEqual(computedGradient.0, expectedGradient)
409+
}
410+
371411
func testAvgPool3D() {
372412
let layer = AvgPool3D<Float>(poolSize: (2, 4, 5), strides: (1, 1, 1), padding: .valid)
373413
let input = Tensor(shape: [1, 2, 4, 5, 1], scalars: (0..<40).map(Float.init))
@@ -376,6 +416,22 @@ final class LayerTests: XCTestCase {
376416
XCTAssertEqual(output, expected)
377417
}
378418

419+
func testAvgPool3DGradient() {
420+
let layer = AvgPool3D<Float>(poolSize: (2, 2, 2), strides: (1, 1, 1), padding: .valid)
421+
let x = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init))
422+
let computedGradient = gradient(at: x, layer) { $1($0).sum() }
423+
// The expected value of the gradient was computed using the following Python code:
424+
// ```
425+
// avgpool3D = tf.keras.layers.AvgPool3D(strides=(1, 1, 1))
426+
// with tf.GradientTape() as t:
427+
// t.watch(x)
428+
// y = tf.math.reduce_sum(avgpool3D(x))
429+
// print(t.gradient(y, x))
430+
// ```
431+
let expectedGradient = Tensor<Float>(repeating: 0.125, shape: [1, 2, 2, 2, 1])
432+
XCTAssertEqual(computedGradient.0, expectedGradient)
433+
}
434+
379435
func testGlobalAvgPool1D() {
380436
let layer = GlobalAvgPool1D<Float>()
381437
let input = Tensor(shape: [2, 5, 1], scalars: (0..<10).map(Float.init))
@@ -1018,8 +1074,11 @@ final class LayerTests: XCTestCase {
10181074
("testMaxPool3D", testMaxPool3D),
10191075
("testMaxPool3DGradient", testMaxPool3DGradient),
10201076
("testAvgPool1D", testAvgPool1D),
1077+
("testAvgPool1DGradient", testAvgPool1DGradient),
10211078
("testAvgPool2D", testAvgPool2D),
1079+
("testAvgPool2DGradient", testAvgPool2DGradient),
10221080
("testAvgPool3D", testAvgPool3D),
1081+
("testAvgPool3DGradient", testAvgPool3DGradient),
10231082
("testGlobalAvgPool1D", testGlobalAvgPool1D),
10241083
("testGlobalAvgPool1DGradient", testGlobalAvgPool1DGradient),
10251084
("testGlobalAvgPool2D", testGlobalAvgPool2D),

0 commit comments

Comments
 (0)