Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Fix Transposed Conv2d error & add test #288

Merged
merged 20 commits into from
Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Sources/TensorFlow/Layers/Convolutional.swift
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ public struct TransposedConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
///
/// - Parameters:
/// - filter: A 4-D tensor of shape
/// `[width, height, input channel count, output channel count]`.
/// `[height, width, output channel count, input channel count]`.
/// - bias: The bias tensor of shape `[output channel count]`.
/// - activation: The element-wise activation function.
/// - strides: The strides of the sliding window for spatial dimensions.
Expand All @@ -404,12 +404,12 @@ public struct TransposedConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
let batchSize = input.shape[0]
let w = (input.shape[1] - (1 * paddingIndex)) *
let h = (input.shape[1] - (1 * paddingIndex)) *
strides.0 + (filter.shape[0] * paddingIndex)
let h = (input.shape[2] - (1 * paddingIndex)) *
let w = (input.shape[2] - (1 * paddingIndex)) *
strides.1 + (filter.shape[1] * paddingIndex)
let c = filter.shape[2]
let newShape = Tensor<Int32>([Int32(batchSize), Int32(w), Int32(h), Int32(c)])
let newShape = Tensor<Int32>([Int32(batchSize), Int32(h), Int32(w), Int32(c)])
return activation(conv2DBackpropInput(
input,
shape: newShape,
Expand Down
13 changes: 13 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,18 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(grads.1.bias, [4, 4, 4, 4])
}

func testTransposedConv2D() {
let filter = Tensor(shape: [4, 2, 1, 1], scalars: (0..<8).map(Float.init))
let bias = Tensor<Float>([8])
let layer = TransposedConv2D(filter: filter, bias: bias, activation: identity,
strides: (1, 1), padding: .same)
let input = Tensor(shape: [1, 4, 2, 1], scalars: (0..<8).map(Float.init))
let output = layer.inferring(from: input)
let expected = Tensor<Float>(shape: [1, 4, 2, 1],
scalars: [8, 12, 12, 28, 24, 64, 48, 112])
XCTAssertEqual(output, expected)
}

func testSeparableConv1D() {
let depthwiseFilter = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init))
let pointwiseFilter = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init))
Expand Down Expand Up @@ -1318,6 +1330,7 @@ final class LayerTests: XCTestCase {
("testConv2DDilation", testConv2DDilation),
("testConv3D", testConv3D),
("testConv3DGradient", testConv3DGradient),
("testTransposedConv2D", testTransposedConv2D),
("testDepthwiseConv2D", testDepthwiseConv2D),
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),
("testSeparableConv1D", testSeparableConv1D),
Expand Down