Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1bbd9f0

Browse files
jon-towrxwei
authored andcommitted
Add MaxPool derivative tests (#444)
#402
1 parent 481c934 commit 1bbd9f0

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,26 @@ final class LayerTests: XCTestCase {
214214
XCTAssertEqual(output, expected)
215215
}
216216

217+
func testMaxPool1DGradient() {
218+
let layer = MaxPool1D<Float>(poolSize: 2, stride: 1, padding: .valid)
219+
let x = Tensor<Float>(shape: [1, 4, 4], scalars: (0..<16).map(Float.init))
220+
let computedGradient = gradient(at: x, layer) { $1($0).sum() }
221+
// The expected value of the gradient was computed using the following Python code:
222+
// ```
223+
// maxpool1D = tf.keras.layers.MaxPool1D()
224+
// with tf.GradientTape() as t:
225+
// t.watch(x)
226+
// y = tf.math.reduce_sum(maxpool1D(x))
227+
// print(t.gradient(y, x))
228+
// ```
229+
let expectedGradient = Tensor<Float>([[
230+
[0, 0, 0, 0],
231+
[1, 1, 1, 1],
232+
[1, 1, 1, 1],
233+
[1, 1, 1, 1]]])
234+
XCTAssertEqual(computedGradient.0, expectedGradient)
235+
}
236+
217237
func testMaxPool2D() {
218238
let layer = MaxPool2D<Float>(poolSize: (2, 2), strides: (1, 1), padding: .valid)
219239
let input = Tensor(shape: [1, 2, 2, 1], scalars: (0..<4).map(Float.init))
@@ -222,6 +242,26 @@ final class LayerTests: XCTestCase {
222242
XCTAssertEqual(output, expected)
223243
}
224244

245+
func testMaxPool2DGradient() {
246+
let layer = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2), padding: .valid)
247+
let x = Tensor(shape: [1, 4, 4, 1], scalars: (0..<16).map(Float.init))
248+
let computedGradient = gradient(at: x, layer) { $1($0).sum() }
249+
// The expected value of the gradient was computed using the following Python code:
250+
// ```
251+
// maxpool2D = tf.keras.layers.MaxPool2D(strides=(2, 2))
252+
// with tf.GradientTape() as t:
253+
// t.watch(x)
254+
// y = tf.math.reduce_sum(maxpool2D(x))
255+
// print(t.gradient(y, x))
256+
// ```
257+
let expectedGradient = Tensor<Float>([[
258+
[[0], [0], [0], [0]],
259+
[[0], [1], [0], [1]],
260+
[[0], [0], [0], [0]],
261+
[[0], [1], [0], [1]]]])
262+
XCTAssertEqual(computedGradient.0, expectedGradient)
263+
}
264+
225265
func testMaxPool3D() {
226266
let layer = MaxPool3D<Float>(poolSize: (2, 2, 2), strides: (1, 1, 1), padding: .valid)
227267
let input = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init))
@@ -230,6 +270,26 @@ final class LayerTests: XCTestCase {
230270
XCTAssertEqual(output, expected)
231271
}
232272

273+
func testMaxPool3DGradient(){
274+
let layer = MaxPool3D<Float>(poolSize: (2, 2, 2), strides: (1, 1, 1), padding: .valid)
275+
let x = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init))
276+
let computedGradient = gradient(at: x, layer) { $1($0).sum() }
277+
// The expected value of the gradient was computed using the following Python code:
278+
// ```
279+
// maxpool3D = tf.keras.layers.MaxPool3D(strides=(1, 1, 1))
280+
// with tf.GradientTape() as t:
281+
// t.watch(x)
282+
// y = tf.math.reduce_sum(maxpool3D(x))
283+
// print(t.gradient(y, x))
284+
// ```
285+
let expectedGradient = Tensor<Float>([[
286+
[[[0], [0]],
287+
[[0], [0]]],
288+
[[[0], [0]],
289+
[[0], [1]]]]])
290+
XCTAssertEqual(computedGradient.0, expectedGradient)
291+
}
292+
233293
func testAvgPool1D() {
234294
let layer = AvgPool1D<Float>(poolSize: 3, stride: 1, padding: .valid)
235295
let input = Tensor<Float>([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2)
@@ -645,8 +705,11 @@ final class LayerTests: XCTestCase {
645705
("testZeroPadding2D", testZeroPadding2D),
646706
("testZeroPadding3D", testZeroPadding3D),
647707
("testMaxPool1D", testMaxPool1D),
708+
("testMaxPool1DGradient", testMaxPool1DGradient),
648709
("testMaxPool2D", testMaxPool2D),
710+
("testMaxPool2DGradient", testMaxPool2DGradient),
649711
("testMaxPool3D", testMaxPool3D),
712+
("testMaxPool3DGradient", testMaxPool3DGradient),
650713
("testAvgPool1D", testAvgPool1D),
651714
("testAvgPool2D", testAvgPool2D),
652715
("testAvgPool3D", testAvgPool3D),

0 commit comments

Comments
 (0)