Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add Separable Conv2D Layer #362

Merged
merged 6 commits into from
Aug 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 110 additions & 2 deletions Sources/TensorFlow/Layers/Convolutional.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
@noDerivative public let padding: Padding
/// The dilation factor for spatial dimensions.
@noDerivative public let dilations: (Int, Int)

/// The element-wise activation function type.
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

Expand Down Expand Up @@ -470,7 +470,7 @@ public struct DepthwiseConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
/// The element-wise activation function type.
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

/// Creates a `DepthwiseConv2D` layer with the specified filter, bias, activation function,
/// Creates a `DepthwiseConv2D` layer with the specified filter, bias, activation function,
/// strides, and padding.
///
/// - Parameters:
Expand Down Expand Up @@ -631,3 +631,111 @@ public struct ZeroPadding3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
return input.padded(forSizes: [padding.0, padding.1, padding.2])
}
}

/// A 2-D Separable convolution layer.
///
/// This layer performs a depthwise convolution that acts separately on channels followed by
/// a pointwise convolution that mixes channels.
@frozen
public struct SeparableConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
/// The 4-D depthwise convolution kernel.
public var depthwiseFilter: Tensor<Scalar>
/// The 4-D pointwise convolution kernel.
public var pointwiseFilter: Tensor<Scalar>
/// The bias vector.
public var bias: Tensor<Scalar>
/// The element-wise activation function.
@noDerivative public let activation: Activation
/// The strides of the sliding window for spatial dimensions.
@noDerivative public let strides: (Int, Int)
/// The padding algorithm for convolution.
@noDerivative public let padding: Padding

/// The element-wise activation function type.
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

/// Creates a `SeparableConv2D` layer with the specified depthwise and pointwise filter,
/// bias, activation function, strides, and padding.
///
/// - Parameters:
/// - depthwiseFilter: The 4-D depthwise convolution kernel
/// `[filter height, filter width, input channels count, channel multiplier]`.
/// - pointwiseFilter: The 4-D pointwise convolution kernel
/// `[1, 1, channel multiplier * input channels count, output channels count]`.
/// - bias: The bias vector.
/// - activation: The element-wise activation function.
/// - strides: The strides of the sliding window for spatial dimensions.
/// - padding: The padding algorithm for convolution.
public init(
depthwiseFilter: Tensor<Scalar>,
pointwiseFilter: Tensor<Scalar>,
bias: Tensor<Scalar>,
activation: @escaping Activation = identity,
strides: (Int, Int) = (1, 1),
padding: Padding = .valid
) {
self.depthwiseFilter = depthwiseFilter
self.pointwiseFilter = pointwiseFilter
self.bias = bias
self.activation = activation
self.strides = strides
self.padding = padding
}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
let depthwise = depthwiseConv2D(
input,
filter: depthwiseFilter,
strides: (1, strides.0, strides.1, 1),
padding: padding)
return activation(conv2D(
depthwise,
filter: pointwiseFilter,
strides: (1, 1, 1, 1),
padding: padding,
dilations: (1, 1, 1, 1)) + bias)
}
}

public extension SeparableConv2D {
/// Creates a `SeparableConv2D` layer with the specified depthwise and pointwise filter shape,
/// strides, padding, and element-wise activation function.
///
/// - Parameters:
/// - depthwiseFilterShape: The shape of the 4-D depthwise convolution kernel.
/// - pointwiseFilterShape: The shape of the 4-D pointwise convolution kernel.
/// - strides: The strides of the sliding window for spatial/spatio-temporal dimensions.
/// - padding: The padding algorithm for convolution.
/// - activation: The element-wise activation function.
/// - filterInitializer: Initializer to use for the filter parameters.
/// - biasInitializer: Initializer to use for the bias parameters.
init(
depthwiseFilterShape: (Int, Int, Int, Int),
pointwiseFilterShape: (Int, Int, Int, Int),
strides: (Int, Int) = (1, 1),
padding: Padding = .valid,
activation: @escaping Activation = identity,
depthwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
pointwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
biasInitializer: ParameterInitializer<Scalar> = zeros()
) {
let depthwiseFilterTensorShape = TensorShape([
depthwiseFilterShape.0, depthwiseFilterShape.1, depthwiseFilterShape.2,
depthwiseFilterShape.3])
let pointwiseFilterTensorShape = TensorShape([
pointwiseFilterShape.0, pointwiseFilterShape.1, pointwiseFilterShape.2,
pointwiseFilterShape.3])
self.init(
depthwiseFilter: depthwiseFilterInitializer(depthwiseFilterTensorShape),
pointwiseFilter: pointwiseFilterInitializer(pointwiseFilterTensorShape),
bias: biasInitializer([pointwiseFilterShape.3]),
activation: activation,
strides: strides,
padding: padding)
}
}
19 changes: 19 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testSeparableConv2D() {
let depthwiseFilter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
let pointwiseFilter = Tensor(shape: [1, 1, 4, 1], scalars: (0..<4).map(Float.init))
let bias = Tensor<Float>([4])
let layer = SeparableConv2D<Float>(depthwiseFilter: depthwiseFilter,
pointwiseFilter: pointwiseFilter,
bias: bias,
activation: identity,
strides: (2, 2),
padding: .valid)
let input = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
let output = layer.inferring(from: input)
let expected = Tensor<Float>(shape: [2, 1, 1, 1],
scalars: [1016, 2616])
XCTAssertEqual(output, expected)
}


func testZeroPadding1D() {
let input = Tensor<Float>([0.0, 1.0, 2.0])
let layer = ZeroPadding1D<Float>(padding: 2)
Expand Down Expand Up @@ -538,6 +556,7 @@ final class LayerTests: XCTestCase {
("testConv2DDilation", testConv2DDilation),
("testConv3D", testConv3D),
("testDepthConv2D", testDepthConv2D),
("testSeparableConv2D", testSeparableConv2D),
("testZeroPadding1D", testZeroPadding1D),
("testZeroPadding2D", testZeroPadding2D),
("testZeroPadding3D", testZeroPadding3D),
Expand Down