Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add Conv3D gradient test and vjp fixes #460

Merged
merged 14 commits into from
Sep 24, 2019
28 changes: 11 additions & 17 deletions Sources/TensorFlow/Operators/NN.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,10 @@ func _vjpConv3D<Scalar: TensorFlowFloatingPoint>(
let value = conv3D(x, filter: filter, strides: strides,
padding: padding)
return (value, { v in
return (
conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter,
strides: strides, padding: padding),
conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor,
strides: strides, padding: padding)
)
(conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter,
strides: strides, padding: padding),
conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor,
strides: strides, padding: padding))
})
}

Expand Down Expand Up @@ -268,11 +266,9 @@ func _vjpConv3DBackpropInput<Scalar: TensorFlowFloatingPoint>(
let value = conv3DBackpropInput(x, shape: shape, filter: filter, strides: strides,
padding: padding)
return (value, { v in
return (
conv3DBackpropFilter(x, input: v, filterSizes: shape, strides: strides,
padding: padding),
conv3D(v, filter: filter, strides: strides, padding: padding)
)
(conv3D(v, filter: filter, strides: strides, padding: padding),
conv3DBackpropFilter(x, input: v, filterSizes: filter.shapeTensor, strides: strides,
padding: padding))
})
}

Expand All @@ -287,7 +283,7 @@ func conv3DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
padding: Padding = .valid
) -> Tensor<Scalar> {
return Raw.conv3DBackpropFilterV2(
x,
input,
filterSizes: filterSizes,
outBackprop: x,
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2),
Expand All @@ -306,11 +302,9 @@ func _vjpConv3DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
let value = conv3DBackpropFilter(x, input: input, filterSizes: filterSizes,
strides: strides, padding: padding)
return (value, { v in
return (
conv3DBackpropInput(x, shape: filterSizes, filter: v, strides: strides,
padding: padding),
conv3D(input, filter: v, strides: strides, padding: padding)
)
(conv3D(input, filter: v, strides: strides, padding: padding),
conv3DBackpropInput(x, shape: x.shapeTensor, filter: v, strides: strides,
padding: padding))
})
}

Expand Down
50 changes: 50 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,55 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

/// Checks the gradients of `sum(Conv3D(input))` with respect to the input, the filter and
/// the bias against reference values produced by TensorFlow's Python API.
func testConv3DGradient() {
    let filter = Tensor<Float>(shape: [1, 4, 4, 1, 1], scalars: (0..<16).map { Float($0) })
    let bias = Tensor<Float>(ones: [2])
    let conv = Conv3D(
        filter: filter,
        bias: bias,
        activation: identity,
        strides: (2, 2, 2),
        padding: .same)
    let input = Tensor<Float>(shape: [1, 4, 4, 4, 1], scalars: (0..<64).map { Float($0) })

    // Differentiate the summed layer output with respect to both the input and the layer.
    let grads = gradient(at: input, conv) { x, layer in layer(x).sum() }

    // The reference gradients below were obtained with the following Python code:
    // ```
    // import tensorflow as tf
    // x = tf.reshape(tf.range(64, dtype=tf.float32), [1, 4, 4, 4, 1])
    // filter = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1, 1])
    // bias = tf.ones([2])
    // with tf.GradientTape() as tape:
    //     tape.watch([x, filter, bias])
    //     y = tf.math.reduce_sum(tf.nn.conv3d(input=x,
    //                                         filters=filter,
    //                                         strides=[1, 2, 2, 2, 1],
    //                                         padding="SAME") + bias)
    // print(tape.gradient(y, [x, filter, bias]))
    // ```
    let expectedInputGradient: Tensor<Float> =
        [[[[[10.0], [20.0], [24.0], [12.0]],
           [[20.0], [40.0], [48.0], [24.0]],
           [[36.0], [72.0], [80.0], [40.0]],
           [[18.0], [36.0], [40.0], [20.0]]],
          [[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
           [[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
           [[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
           [[ 0.0], [ 0.0], [ 0.0], [ 0.0]]],
          [[[10.0], [20.0], [24.0], [12.0]],
           [[20.0], [40.0], [48.0], [24.0]],
           [[36.0], [72.0], [80.0], [40.0]],
           [[18.0], [36.0], [40.0], [20.0]]],
          [[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
           [[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
           [[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
           [[ 0.0], [ 0.0], [ 0.0], [ 0.0]]]]]
    let expectedFilterGradient: Tensor<Float> =
        [[[[[ 84.0]], [[168.0]], [[176.0]], [[ 88.0]]],
          [[[168.0]], [[336.0]], [[352.0]], [[176.0]]],
          [[[200.0]], [[400.0]], [[416.0]], [[208.0]]],
          [[[100.0]], [[200.0]], [[208.0]], [[104.0]]]]]
    let expectedBiasGradient: Tensor<Float> = [8.0, 8.0]

    XCTAssertEqual(grads.0, expectedInputGradient)
    XCTAssertEqual(grads.1.filter, expectedFilterGradient)
    XCTAssertEqual(grads.1.bias, expectedBiasGradient)
}

func testDepthwiseConv2D() {
let filter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
let bias = Tensor<Float>([1, 2, 3, 4])
Expand Down Expand Up @@ -1237,6 +1286,7 @@ final class LayerTests: XCTestCase {
("testConv2DGradient", testConv2DGradient),
("testConv2DDilation", testConv2DDilation),
("testConv3D", testConv3D),
("testConv3DGradient", testConv3DGradient),
("testDepthwiseConv2D", testDepthwiseConv2D),
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),
("testSeparableConv1D", testSeparableConv1D),
Expand Down