Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 35dfddf

Browse files
Shashi456marcrasi
authored andcommitted
Fix Transposed Conv2d error & add test (#288)
1 parent b7ba0d5 commit 35dfddf

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ public struct TransposedConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
377377
///
378378
/// - Parameters:
379379
/// - filter: A 4-D tensor of shape
380-
/// `[width, height, input channel count, output channel count]`.
380+
/// `[height, width, output channel count, input channel count]`.
381381
/// - bias: The bias tensor of shape `[output channel count]`.
382382
/// - activation: The element-wise activation function.
383383
/// - strides: The strides of the sliding window for spatial dimensions.
@@ -404,12 +404,12 @@ public struct TransposedConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
404404
@differentiable
405405
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
406406
let batchSize = input.shape[0]
407-
let w = (input.shape[1] - (1 * paddingIndex)) *
407+
let h = (input.shape[1] - (1 * paddingIndex)) *
408408
strides.0 + (filter.shape[0] * paddingIndex)
409-
let h = (input.shape[2] - (1 * paddingIndex)) *
409+
let w = (input.shape[2] - (1 * paddingIndex)) *
410410
strides.1 + (filter.shape[1] * paddingIndex)
411411
let c = filter.shape[2]
412-
let newShape = Tensor<Int32>([Int32(batchSize), Int32(w), Int32(h), Int32(c)])
412+
let newShape = Tensor<Int32>([Int32(batchSize), Int32(h), Int32(w), Int32(c)])
413413
return activation(conv2DBackpropInput(
414414
input,
415415
shape: newShape,

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,18 @@ final class LayerTests: XCTestCase {
302302
XCTAssertEqual(grads.1.bias, [4, 4, 4, 4])
303303
}
304304

305+
func testTransposedConv2D() {
306+
let filter = Tensor(shape: [4, 2, 1, 1], scalars: (0..<8).map(Float.init))
307+
let bias = Tensor<Float>([8])
308+
let layer = TransposedConv2D(filter: filter, bias: bias, activation: identity,
309+
strides: (1, 1), padding: .same)
310+
let input = Tensor(shape: [1, 4, 2, 1], scalars: (0..<8).map(Float.init))
311+
let output = layer.inferring(from: input)
312+
let expected = Tensor<Float>(shape: [1, 4, 2, 1],
313+
scalars: [8, 12, 12, 28, 24, 64, 48, 112])
314+
XCTAssertEqual(output, expected)
315+
}
316+
305317
func testSeparableConv1D() {
306318
let depthwiseFilter = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init))
307319
let pointwiseFilter = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init))
@@ -1318,6 +1330,7 @@ final class LayerTests: XCTestCase {
13181330
("testConv2DDilation", testConv2DDilation),
13191331
("testConv3D", testConv3D),
13201332
("testConv3DGradient", testConv3DGradient),
1333+
("testTransposedConv2D", testTransposedConv2D),
13211334
("testDepthwiseConv2D", testDepthwiseConv2D),
13221335
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),
13231336
("testSeparableConv1D", testSeparableConv1D),

0 commit comments

Comments
 (0)