Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 563aea6

Browse files
mikowalsrxwei
authored andcommitted
add test for Conv2D gradients (#443)
1 parent 1bbd9f0 commit 563aea6

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,52 @@ final class LayerTests: XCTestCase {
110110
XCTAssertEqual(output, expected)
111111
}
112112

113+
func testConv2DGradient() {
114+
let filter = Tensor(shape: [3, 3, 2, 4], scalars: (0..<72).map(Float.init))
115+
let bias = Tensor<Float>(zeros: [4])
116+
let layer = Conv2D<Float>(filter: filter,
117+
bias: bias,
118+
activation: identity,
119+
strides: (2, 2),
120+
padding: .valid)
121+
let input = Tensor(shape: [2, 4, 4, 2], scalars: (0..<64).map(Float.init))
122+
let grads = gradient( at: input, layer) { $1($0).sum() }
123+
// The expected gradients were computed using the following Python code:
124+
// ```
125+
// x = tf.reshape(tf.range(64, dtype=tf.float32), [2, 4, 4, 2])
126+
// filter = tf.reshape(tf.range(72, dtype=tf.float32), [3, 3, 2, 4])
127+
// bias = tf.zeros([4])
128+
// with tf.GradientTape() as t:
129+
// t.watch([x, filter, bias])
130+
// y = tf.math.reduce_sum(tf.nn.conv2d(input=x,
131+
// filters=filter,
132+
// strides=[1, 2, 2, 1],
133+
// data_format="NHWC",
134+
// padding="VALID") + bias)
135+
// grads = t.gradient(y, [x, filter, bias])
136+
// ```
137+
XCTAssertEqual(grads.0,
138+
[[[[ 6, 22], [ 38, 54], [ 70, 86], [ 0, 0]],
139+
[[102, 118], [134, 150], [166, 182], [ 0, 0]],
140+
[[198, 214], [230, 246], [262, 278], [ 0, 0]],
141+
[[ 0, 0], [ 0, 0], [ 0, 0], [ 0, 0]]],
142+
[[[ 6, 22], [ 38, 54], [ 70, 86], [ 0, 0]],
143+
[[102, 118], [134, 150], [166, 182], [ 0, 0]],
144+
[[198, 214], [230, 246], [262, 278], [ 0, 0]],
145+
[[ 0, 0], [ 0, 0], [ 0, 0], [ 0, 0]]]])
146+
XCTAssertEqual(grads.1.filter,
147+
[[[[32, 32, 32, 32], [34, 34, 34, 34]],
148+
[[36, 36, 36, 36], [38, 38, 38, 38]],
149+
[[40, 40, 40, 40], [42, 42, 42, 42]]],
150+
[[[48, 48, 48, 48], [50, 50, 50, 50]],
151+
[[52, 52, 52, 52], [54, 54, 54, 54]],
152+
[[56, 56, 56, 56], [58, 58, 58, 58]]],
153+
[[[64, 64, 64, 64], [66, 66, 66, 66]],
154+
[[68, 68, 68, 68], [70, 70, 70, 70]],
155+
[[72, 72, 72, 72], [74, 74, 74, 74]]]])
156+
XCTAssertEqual(grads.1.bias, [2, 2, 2, 2])
157+
}
158+
113159
func testConv2DDilation() {
114160
// Input shapes. (Data format = NHWC)
115161
let batchSize = 2
@@ -697,6 +743,7 @@ final class LayerTests: XCTestCase {
697743
("testConv1D", testConv1D),
698744
("testConv1DDilation", testConv1DDilation),
699745
("testConv2D", testConv2D),
746+
("testConv2DGradient", testConv2DGradient),
700747
("testConv2DDilation", testConv2DDilation),
701748
("testConv3D", testConv3D),
702749
("testDepthConv2D", testDepthConv2D),

0 commit comments

Comments
 (0)