Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Adding conv transpose 1d & 3d #174

Merged
merged 20 commits into from
Nov 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions Sources/TensorFlow/Layers/Convolutional.swift
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,102 @@ public extension Conv3D {
}
}

/// A 1-D transposed convolution layer (e.g. temporal transposed convolution over a
/// time-series).
///
/// This layer creates a convolution filter that is transpose-convolved with the layer input
/// to produce a tensor of outputs.
@frozen
public struct TransposedConv1D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The 1-D convolution kernel; its last dimension determines the output channel
    /// count (and therefore the bias size).
    public var filter: Tensor<Scalar>
    /// The bias vector, added to every output position.
    public var bias: Tensor<Scalar>
    /// The element-wise activation function.
    @noDerivative public let activation: Activation
    /// The stride of the sliding window for the temporal dimension.
    @noDerivative public let stride: Int
    /// The padding algorithm for convolution.
    @noDerivative public let padding: Padding
    /// A multiplier used in the output-size computation: `0` for `.same` padding
    /// (output width is `input * stride`) and `1` otherwise
    /// (output width is `(input - 1) * stride + filterWidth`).
    @noDerivative public let paddingIndex: Int

    /// The element-wise activation function type.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

    /// Creates a `TransposedConv1D` layer with the specified filter, bias,
    /// activation function, stride, and padding.
    ///
    /// - Parameters:
    ///   - filter: The 3-D convolution kernel.
    ///   - bias: The bias vector.
    ///   - activation: The element-wise activation function.
    ///   - stride: The stride of the sliding window for the temporal dimension.
    ///   - padding: The padding algorithm for convolution.
    public init(
        filter: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation = identity,
        stride: Int = 1,
        padding: Padding = .valid
    ) {
        self.filter = filter
        self.bias = bias
        self.activation = activation
        self.stride = stride
        self.padding = padding
        // `.same` keeps `input * stride`; `.valid` additionally grows by the filter width.
        self.paddingIndex = padding == .same ? 0 : 1
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, with shape `[batch, width, channels]`.
    /// - Returns: The output, with shape `[batch, newWidth, outputChannels]`.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        let batchSize = input.shape[0]
        // Output width: `input * stride` for `.same` (paddingIndex == 0),
        // `(input - 1) * stride + filterWidth` for `.valid` (paddingIndex == 1).
        let w = (input.shape[1] - (1 * paddingIndex)) *
            stride + (filter.shape[0] * paddingIndex)
        // NOTE(review): `conv2DBackpropInput` takes the transposed output channel count
        // from the expanded filter's third dimension (`filter.shape[1]` here); verify
        // that using `filter.shape[2]` is intended — the two coincide whenever the
        // input and output channel counts are equal, as in the current test.
        let c = filter.shape[2]
        let newShape = Tensor<Int32>([Int32(batchSize), 1, Int32(w), Int32(c)])
        // The 1-D transposed convolution is implemented by inserting a unit "height"
        // dimension, delegating to the 2-D backprop-input op, then squeezing the unit
        // dimension back out so a rank-3 input yields a rank-3 output.
        return activation(conv2DBackpropInput(
            input.expandingShape(at: 1),
            shape: newShape,
            filter: filter.expandingShape(at: 0),
            strides: (1, 1, stride, 1),
            padding: padding).squeezingShape(at: 1) + bias)
    }
}

public extension TransposedConv1D {
    /// Creates a `TransposedConv1D` layer with the specified filter shape, stride, padding, and
    /// element-wise activation function. The filter tensor is initialized using Glorot uniform
    /// initialization by default; the bias vector is initialized with zeros by default.
    ///
    /// - Parameters:
    ///   - filterShape: The shape of the 3-D convolution kernel; `filterShape.2` is the
    ///     output channel count and determines the bias size.
    ///   - stride: The stride of the sliding window for the temporal dimension.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - filterInitializer: Initializer to use for the filter parameters.
    ///   - biasInitializer: Initializer to use for the bias parameters.
    init(
        filterShape: (Int, Int, Int),
        stride: Int = 1,
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        filterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
        biasInitializer: ParameterInitializer<Scalar> = zeros()
    ) {
        let filterTensorShape = TensorShape([
            filterShape.0, filterShape.1, filterShape.2])
        self.init(
            filter: filterInitializer(filterTensorShape),
            bias: biasInitializer([filterShape.2]),
            activation: activation,
            stride: stride,
            padding: padding)
    }
}

/// A 2-D transposed convolution layer (e.g. spatial transposed convolution over images).
///
/// This layer creates a convolution filter that is transpose-convolved with the layer input
Expand Down Expand Up @@ -449,6 +545,107 @@ public extension TransposedConv2D {
}
}


/// A 3-D transposed convolution layer (e.g. spatial transposed convolution over volumes).
///
/// This layer creates a convolution filter that is transpose-convolved with the layer input
/// to produce a tensor of outputs.
@frozen
public struct TransposedConv3D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The 5-D convolution kernel; its fourth dimension determines the output channel
    /// count (and therefore the bias size).
    public var filter: Tensor<Scalar>
    /// The bias vector, added to every output position.
    public var bias: Tensor<Scalar>
    /// The element-wise activation function.
    @noDerivative public let activation: Activation
    /// The strides of the sliding window for the three spatial dimensions.
    @noDerivative public let strides: (Int, Int, Int)
    /// The padding algorithm for convolution.
    @noDerivative public let padding: Padding
    /// A multiplier used in the output-size computation: `0` for `.same` padding
    /// (each spatial output dimension is `input * stride`) and `1` otherwise
    /// (each spatial output dimension is `(input - 1) * stride + filterSize`).
    @noDerivative public let paddingIndex: Int

    /// The element-wise activation function type.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

    /// Creates a `TransposedConv3D` layer with the specified filter, bias,
    /// activation function, strides, and padding.
    ///
    /// - Parameters:
    ///   - filter: The 5-D convolution kernel.
    ///   - bias: The bias vector.
    ///   - activation: The element-wise activation function.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    public init(
        filter: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation = identity,
        strides: (Int, Int, Int) = (1, 1, 1),
        padding: Padding = .valid
    ) {
        self.filter = filter
        self.bias = bias
        self.activation = activation
        self.strides = strides
        self.padding = padding
        // `.same` keeps `input * stride` per dimension; `.valid` additionally grows
        // each dimension by the corresponding filter size.
        self.paddingIndex = padding == .same ? 0 : 1
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, with shape
    ///   `[batch, dim1, dim2, dim3, channels]`.
    /// - Returns: The output.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        let batchSize = input.shape[0]
        // Each output spatial size: `input * stride` for `.same` (paddingIndex == 0),
        // `(input - 1) * stride + filterSize` for `.valid` (paddingIndex == 1).
        // NOTE(review): the locals are named w/h/d for dims 1-3, but TF's 3-D layout
        // is presumably [batch, depth, height, width, channels] — consider renaming;
        // the arithmetic is unaffected. TODO confirm.
        let w = (input.shape[1] - (1 * paddingIndex)) *
            strides.0 + (filter.shape[0] * paddingIndex)
        let h = (input.shape[2] - (1 * paddingIndex)) *
            strides.1 + (filter.shape[1] * paddingIndex)
        let d = (input.shape[3] - (1 * paddingIndex)) *
            strides.2 + (filter.shape[2] * paddingIndex)
        let c = filter.shape[3]
        let newShape = Tensor<Int32>([Int32(batchSize), Int32(w), Int32(h), Int32(d), Int32(c)])
        // Transposed convolution is computed as the gradient of `conv3D` with
        // respect to its input, with the batch/channel strides pinned to 1.
        return activation(conv3DBackpropInput(
            input,
            shape: newShape,
            filter: filter,
            strides: (1, strides.0, strides.1, strides.2, 1),
            padding: padding) + bias)
    }
}

public extension TransposedConv3D {
    /// Creates a `TransposedConv3D` layer with the specified filter shape, strides, padding, and
    /// element-wise activation function. The filter tensor is initialized using Glorot uniform
    /// initialization by default; the bias vector is initialized with zeros by default.
    ///
    /// - Parameters:
    ///   - filterShape: The shape of the 5-D convolution kernel; `filterShape.4` is the
    ///     output channel count and determines the bias size.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - filterInitializer: Initializer to use for the filter parameters.
    ///   - biasInitializer: Initializer to use for the bias parameters.
    init(
        filterShape: (Int, Int, Int, Int, Int),
        strides: (Int, Int, Int) = (1, 1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        filterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
        biasInitializer: ParameterInitializer<Scalar> = zeros()
    ) {
        let filterTensorShape = TensorShape([
            filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
        self.init(
            filter: filterInitializer(filterTensorShape),
            bias: biasInitializer([filterShape.4]),
            activation: activation,
            strides: strides,
            padding: padding)
    }
}

/// A 2-D depthwise convolution layer.
///
/// This layer creates separable convolution filters that are convolved with the layer input to produce a
Expand Down
27 changes: 27 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,18 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(grads.1.bias, [4, 4, 4, 4])
}

func testTransposedConv1D() {
    // Single-channel 1-D transposed convolution: 4-wide kernel, bias of 8,
    // identity activation, stride 1, `.same` padding.
    let filter = Tensor(shape: [4, 1, 1], scalars: (0..<4).map(Float.init))
    let bias = Tensor<Float>([8])
    let layer = TransposedConv1D(filter: filter, bias: bias, activation: identity,
                                 stride: 1, padding: .same)
    let input = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init))
    let output = layer.inferring(from: input)
    // A rank-3 input must yield a rank-3 output: [batch, width, channels].
    // Scalars are the hand-computed transposed convolution plus the bias of 8.
    let expected = Tensor<Float>(shape: [1, 4, 1],
                                 scalars: [8, 9, 12, 18])
    XCTAssertEqual(output, expected)
}

func testTransposedConv2D() {
let filter = Tensor(shape: [4, 2, 1, 1], scalars: (0..<8).map(Float.init))
let bias = Tensor<Float>([8])
Expand All @@ -314,6 +326,19 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}


func testTransposedConv3D() {
    // Single-channel 3-D transposed convolution: a 2×2×2 kernel with scalars 0...7,
    // bias of 8, identity activation, unit strides, `.same` padding.
    let kernel = Tensor(shape: [2, 2, 2, 1, 1], scalars: (0..<8).map(Float.init))
    let offset = Tensor<Float>([8])
    let layer = TransposedConv3D(filter: kernel, bias: offset, activation: identity,
                                 strides: (1, 1, 1), padding: .same)
    let volume = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init))
    let actual = layer.inferring(from: volume)
    // Hand-computed expectation for `.same` padding; every scalar includes the bias of 8.
    let expected = Tensor<Float>(shape: [1, 2, 2, 2, 1],
                                 scalars: [8, 8, 8, 12, 8, 16, 24, 64])
    XCTAssertEqual(actual, expected)
}

func testSeparableConv1D() {
let depthwiseFilter = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init))
let pointwiseFilter = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init))
Expand Down Expand Up @@ -1330,7 +1355,9 @@ final class LayerTests: XCTestCase {
("testConv2DDilation", testConv2DDilation),
("testConv3D", testConv3D),
("testConv3DGradient", testConv3DGradient),
("testTransposedConv1D", testTransposedConv1D),
("testTransposedConv2D", testTransposedConv2D),
("testTransposedConv3D", testTransposedConv3D),
("testDepthwiseConv2D", testDepthwiseConv2D),
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),
("testSeparableConv1D", testSeparableConv1D),
Expand Down