Skip to content

Commit 1b4f825

Browse files
Shashi456 authored and Ricardo Ocampo committed
Add Conv3D gradient test and vjp fixes (tensorflow#460)
1 parent b6f9cef commit 1b4f825

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

Sources/TensorFlow/Operators/NN.swift

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,10 @@ func _vjpConv3D<Scalar: TensorFlowFloatingPoint>(
229229
let value = conv3D(x, filter: filter, strides: strides,
230230
padding: padding)
231231
return (value, { v in
232-
return (
233-
conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter,
234-
strides: strides, padding: padding),
235-
conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor,
236-
strides: strides, padding: padding)
237-
)
232+
(conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter,
233+
strides: strides, padding: padding),
234+
conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor,
235+
strides: strides, padding: padding))
238236
})
239237
}
240238

@@ -268,11 +266,9 @@ func _vjpConv3DBackpropInput<Scalar: TensorFlowFloatingPoint>(
268266
let value = conv3DBackpropInput(x, shape: shape, filter: filter, strides: strides,
269267
padding: padding)
270268
return (value, { v in
271-
return (
272-
conv3DBackpropFilter(x, input: v, filterSizes: shape, strides: strides,
273-
padding: padding),
274-
conv3D(v, filter: filter, strides: strides, padding: padding)
275-
)
269+
(conv3D(v, filter: filter, strides: strides, padding: padding),
270+
conv3DBackpropFilter(x, input: v, filterSizes: filter.shapeTensor, strides: strides,
271+
padding: padding))
276272
})
277273
}
278274

@@ -287,7 +283,7 @@ func conv3DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
287283
padding: Padding = .valid
288284
) -> Tensor<Scalar> {
289285
return Raw.conv3DBackpropFilterV2(
290-
x,
286+
input,
291287
filterSizes: filterSizes,
292288
outBackprop: x,
293289
strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2),
@@ -306,11 +302,9 @@ func _vjpConv3DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
306302
let value = conv3DBackpropFilter(x, input: input, filterSizes: filterSizes,
307303
strides: strides, padding: padding)
308304
return (value, { v in
309-
return (
310-
conv3DBackpropInput(x, shape: filterSizes, filter: v, strides: strides,
311-
padding: padding),
312-
conv3D(input, filter: v, strides: strides, padding: padding)
313-
)
305+
(conv3D(input, filter: v, strides: strides, padding: padding),
306+
conv3DBackpropInput(x, shape: x.shapeTensor, filter: v, strides: strides,
307+
padding: padding))
314308
})
315309
}
316310

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,55 @@ final class LayerTests: XCTestCase {
198198
XCTAssertEqual(output, expected)
199199
}
200200

201+
func testConv3DGradient() {
202+
let filter = Tensor(shape: [1, 4, 4, 1, 1], scalars: (0..<16).map(Float.init))
203+
let bias = Tensor<Float>(ones: [2])
204+
let layer = Conv3D(filter: filter,
205+
bias: bias,
206+
activation: identity,
207+
strides: (2, 2, 2),
208+
padding: .same)
209+
let input = Tensor(shape: [1, 4, 4, 4, 1], scalars: (0..<64).map(Float.init))
210+
let grads = gradient(at: input, layer) { $1($0).sum() }
211+
// The expected value of the gradient was computed using the following Python code:
212+
// ```
213+
// import tensorflow as tf
214+
// x = tf.reshape(tf.range(64, dtype=tf.float32), [1, 4, 4, 4, 1])
215+
// filter = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1, 1])
216+
// bias = tf.ones([2])
217+
// with tf.GradientTape() as tape:
218+
// tape.watch([x, filter, bias])
219+
// y = tf.math.reduce_sum(tf.nn.conv3d(input=x,
220+
// filters=filter,
221+
// strides=[1, 2, 2, 2, 1],
222+
// padding="SAME") + bias)
223+
// print(tape.gradient(y, [x, filter, bias]))
224+
// ```
225+
XCTAssertEqual(grads.0,
226+
[[[[[10.0], [20.0], [24.0], [12.0]],
227+
[[20.0], [40.0], [48.0], [24.0]],
228+
[[36.0], [72.0], [80.0], [40.0]],
229+
[[18.0], [36.0], [40.0], [20.0]]],
230+
[[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
231+
[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
232+
[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
233+
[[ 0.0], [ 0.0], [ 0.0], [ 0.0]]],
234+
[[[10.0], [20.0], [24.0], [12.0]],
235+
[[20.0], [40.0], [48.0], [24.0]],
236+
[[36.0], [72.0], [80.0], [40.0]],
237+
[[18.0], [36.0], [40.0], [20.0]]],
238+
[[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
239+
[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
240+
[[ 0.0], [ 0.0], [ 0.0], [ 0.0]],
241+
[[ 0.0], [ 0.0], [ 0.0], [ 0.0]]]]])
242+
XCTAssertEqual(grads.1.filter,
243+
[[[[[ 84.0]], [[168.0]], [[176.0]], [[ 88.0]]],
244+
[[[168.0]], [[336.0]], [[352.0]], [[176.0]]],
245+
[[[200.0]], [[400.0]], [[416.0]], [[208.0]]],
246+
[[[100.0]], [[200.0]], [[208.0]], [[104.0]]]]])
247+
XCTAssertEqual(grads.1.bias, [8.0, 8.0])
248+
}
249+
201250
func testDepthwiseConv2D() {
202251
let filter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
203252
let bias = Tensor<Float>([1, 2, 3, 4])
@@ -1274,6 +1323,7 @@ final class LayerTests: XCTestCase {
12741323
("testConv2DGradient", testConv2DGradient),
12751324
("testConv2DDilation", testConv2DDilation),
12761325
("testConv3D", testConv3D),
1326+
("testConv3DGradient", testConv3DGradient),
12771327
("testDepthwiseConv2D", testDepthwiseConv2D),
12781328
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),
12791329
("testSeparableConv1D", testSeparableConv1D),

0 commit comments

Comments
 (0)