Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Adding seperable conv2d gradient test #513

Merged
merged 1 commit into from
Sep 24, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,43 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testSeparableConv2DGradient() {
let depthwiseFilter = Tensor(shape: [2, 1, 2, 2], scalars: (0..<8).map(Float.init))
let pointwiseFilter = Tensor(shape: [1, 1, 4, 1], scalars: (0..<4).map(Float.init))
let bias = Tensor<Float>([1, 1])
let layer = SeparableConv2D<Float>(depthwiseFilter: depthwiseFilter,
pointwiseFilter: pointwiseFilter,
bias: bias,
activation: identity,
strides: (1, 1),
padding: .same)
let input = Tensor(shape: [2, 1, 2, 2], scalars: (0..<8).map(Float.init))
let grads = gradient(at: input, layer) { $1($0).sum() }
// The expected value of the gradient was computed using the following Python code:
// ```
// import tensorflow as tf
// x = tf.reshape(tf.range(8, dtype=tf.float32), [2, 1, 2, 2])
// depthwiseFilter = tf.reshape(tf.range(8, dtype=tf.float32), [2, 1, 2, 2])
// pointwiseFilter = tf.reshape(tf.range(4, dtype=tf.float32), [1, 1, 4, 1])
// bias = tf.ones([2])
// with tf.GradientTape() as tape:
// tape.watch([x, depthwiseFilter, pointwiseFilter, bias])
// y = tf.math.reduce_sum(tf.nn.separable_conv2D(input,
// depthwiseFilter,
// pointwiseFilter
// strides=[1, 1, 1, 1],
// padding="SAME") + bias)
// print(tape.gradient(y, [x, depthwiseFilter, pointwiseFilter, bias])
// ```
XCTAssertEqual(grads.0,
[[[[ 2.0, 26.0], [ 2.0, 26.0]]],
[[[ 2.0, 26.0], [ 2.0, 26.0]]]])
XCTAssertEqual(grads.1.depthwiseFilter,
[[[[ 0.0, 24.0], [64.0, 96.0]]],
[[[ 0.0, 0.0], [ 0.0, 0.0]]]])
XCTAssertEqual(grads.1.bias, [4.0, 4.0])
}

func testZeroPadding1D() {
let input = Tensor<Float>(shape: [1, 3, 1], scalars: [0.0, 1.0, 2.0])
let layer = ZeroPadding1D<Float>(padding: 2)
Expand Down Expand Up @@ -1233,6 +1270,7 @@ final class LayerTests: XCTestCase {
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),
("testSeparableConv1D", testSeparableConv1D),
("testSeparableConv2D", testSeparableConv2D),
("testSeparableConv2DGradient", testSeparableConv2DGradient),
("testZeroPadding1D", testZeroPadding1D),
("testZeroPadding1DGradient", testZeroPadding1DGradient),
("testZeroPadding2D", testZeroPadding2D),
Expand Down