Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add AvgPool derivative tests #472

Merged
merged 6 commits into from
Aug 25, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,26 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testAvgPool1DGradient() {
    // Checks the pullback of `AvgPool1D` against a reference gradient from Python TensorFlow.
    let pool = AvgPool1D<Float>(poolSize: 2, stride: 1, padding: .valid)
    let input = Tensor(shape: [1, 4, 4], scalars: (0..<16).map(Float.init))
    let computedGradient = gradient(at: input, pool) { $1($0).sum() }
    // The expected value of the gradient was computed using the following Python code:
    // ```
    // x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4])
    // avgpool1D = tf.keras.layers.AvgPool1D(strides=1)
    // with tf.GradientTape() as t:
    //     t.watch(x)
    //     y = tf.math.reduce_sum(avgpool1D(x))
    // print(t.gradient(y, x))
    // ```
    // Temporal positions on the boundary fall in one 2-wide window (1/2), interior
    // positions in two windows (2 * 1/2 = 1); the value is uniform across channels.
    let expectedGradient = Tensor<Float>([[
        [0.5, 0.5, 0.5, 0.5],
        [1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0],
        [0.5, 0.5, 0.5, 0.5]]])
    XCTAssertEqual(computedGradient.0, expectedGradient)
}

func testAvgPool2D() {
let layer = AvgPool2D<Float>(poolSize: (2, 5), strides: (1, 1), padding: .valid)
let input = Tensor(shape: [1, 2, 5, 1], scalars: (0..<10).map(Float.init))
Expand All @@ -368,6 +388,26 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testAvgPool2DGradient() {
    // Checks the pullback of `AvgPool2D` against a reference gradient from Python TensorFlow.
    let pool = AvgPool2D<Float>(poolSize: (2, 2), strides: (1, 1), padding: .valid)
    let input = Tensor(shape: [1, 4, 4, 2], scalars: (0..<32).map(Float.init))
    let computedGradient = gradient(at: input, pool) { $1($0).sum() }
    // The expected value of the gradient was computed using the following Python code:
    // ```
    // x = tf.reshape(tf.range(32, dtype=tf.float32), [1, 4, 4, 2])
    // avgpool2D = tf.keras.layers.AvgPool2D(strides=(1, 1))
    // with tf.GradientTape() as t:
    //     t.watch(x)
    //     y = tf.math.reduce_sum(avgpool2D(x))
    // print(t.gradient(y, x))
    // ```
    // Each 2x2 window contributes 1/4 per element: corners lie in one window (0.25),
    // edges in two (0.5), interior cells in four (1.0); identical for both channels.
    let expectedGradient = Tensor<Float>([[
        [[0.25, 0.25], [0.50, 0.50], [0.50, 0.50], [0.25, 0.25]],
        [[0.50, 0.50], [1.00, 1.00], [1.00, 1.00], [0.50, 0.50]],
        [[0.50, 0.50], [1.00, 1.00], [1.00, 1.00], [0.50, 0.50]],
        [[0.25, 0.25], [0.50, 0.50], [0.50, 0.50], [0.25, 0.25]]]])
    XCTAssertEqual(computedGradient.0, expectedGradient)
}

func testAvgPool3D() {
let layer = AvgPool3D<Float>(poolSize: (2, 4, 5), strides: (1, 1, 1), padding: .valid)
let input = Tensor(shape: [1, 2, 4, 5, 1], scalars: (0..<40).map(Float.init))
Expand All @@ -376,6 +416,22 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testAvgPool3DGradient() {
    // Checks the pullback of `AvgPool3D` against a reference gradient from Python TensorFlow.
    let pool = AvgPool3D<Float>(poolSize: (2, 2, 2), strides: (1, 1, 1), padding: .valid)
    let input = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init))
    let computedGradient = gradient(at: input, pool) { $1($0).sum() }
    // The expected value of the gradient was computed using the following Python code:
    // ```
    // x = tf.reshape(tf.range(8, dtype=tf.float32), [1, 2, 2, 2, 1])
    // avgpool3D = tf.keras.layers.AvgPool3D(strides=(1, 1, 1))
    // with tf.GradientTape() as t:
    //     t.watch(x)
    //     y = tf.math.reduce_sum(avgpool3D(x))
    // print(t.gradient(y, x))
    // ```
    // The input is covered by a single 2x2x2 window, so every one of the 8 elements
    // receives gradient 1/8 = 0.125.
    let expectedGradient = Tensor<Float>(repeating: 0.125, shape: [1, 2, 2, 2, 1])
    XCTAssertEqual(computedGradient.0, expectedGradient)
}

func testGlobalAvgPool1D() {
let layer = GlobalAvgPool1D<Float>()
let input = Tensor(shape: [2, 5, 1], scalars: (0..<10).map(Float.init))
Expand Down Expand Up @@ -923,8 +979,11 @@ final class LayerTests: XCTestCase {
("testMaxPool3D", testMaxPool3D),
("testMaxPool3DGradient", testMaxPool3DGradient),
("testAvgPool1D", testAvgPool1D),
("testAvgPool1DGradient", testAvgPool1DGradient),
("testAvgPool2D", testAvgPool2D),
("testAvgPool2DGradient", testAvgPool2DGradient),
("testAvgPool3D", testAvgPool3D),
("testAvgPool3DGradient", testAvgPool3DGradient),
("testGlobalAvgPool1D", testGlobalAvgPool1D),
("testGlobalAvgPool1DGradient", testGlobalAvgPool1DGradient),
("testGlobalAvgPool2D", testGlobalAvgPool2D),
Expand Down