Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add DepthwiseConv2D #241

Merged
merged 12 commits into from
Jun 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions Sources/TensorFlow/Layers/Convolutional.swift
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,119 @@ public extension TransposedConv2D {
padding: padding)
}
}

/// A 2-D depthwise convolution layer.
///
/// This layer creates separable convolution filters that are convolved with the layer input to
/// produce a tensor of outputs. A bias is added to the convolution result before the activation
/// function is applied.
@frozen
public struct DepthwiseConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The 4-D convolution kernel.
    public var filter: Tensor<Scalar>
    /// The bias vector, added to the result of the convolution.
    public var bias: Tensor<Scalar>
    /// The type of the differentiable element-wise activation function.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
    /// The element-wise activation function applied to the biased convolution result.
    @noDerivative public let activation: Activation
    /// The strides of the sliding window for the two spatial dimensions.
    @noDerivative public let strides: (Int, Int)
    /// The padding algorithm for convolution.
    @noDerivative public let padding: Padding

    /// Creates a `DepthwiseConv2D` layer from already-constructed parameters.
    ///
    /// - Parameters:
    ///   - filter: The 4-D convolution kernel.
    ///   - bias: The bias vector.
    ///   - activation: The element-wise activation function.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    public init(
        filter: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation,
        strides: (Int, Int),
        padding: Padding
    ) {
        // Non-differentiable configuration first, trainable parameters last.
        self.padding = padding
        self.strides = strides
        self.activation = activation
        self.bias = bias
        self.filter = filter
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The activation applied to the depthwise convolution of `input` plus `bias`.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        // The layer only strides over the two middle dimensions; the outer two stay at 1.
        let convolved = depthwiseConv2D(
            input,
            filter: filter,
            strides: (1, strides.0, strides.1, 1),
            padding: padding)
        return activation(convolved + bias)
    }
}

public extension DepthwiseConv2D {
    /// Creates a `DepthwiseConv2D` layer with the specified filter shape, strides, padding, and
    /// element-wise activation function. The filter tensor is initialized using Glorot uniform
    /// initialization with the specified generator. The bias vector is initialized with zeros.
    ///
    /// - Parameters:
    ///   - filterShape: The shape of the 4-D convolution kernel, given as
    ///     (height, width, input channel count, channel multiplier).
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - generator: The random number generator for initialization.
    ///
    /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random
    ///   initialization.
    init<G: RandomNumberGenerator>(
        filterShape: (Int, Int, Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        generator: inout G
    ) {
        let filterTensorShape = TensorShape([
            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
        // A depthwise convolution produces one output channel per
        // (input channel, channel multiplier) pair, so the bias needs
        // `filterShape.2 * filterShape.3` entries. Using only `filterShape.3` fails to
        // broadcast against the convolution output whenever both values exceed 1.
        let outputChannelCount = filterShape.2 * filterShape.3
        self.init(
            filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
            bias: Tensor(zeros: [outputChannelCount]),
            activation: activation,
            strides: strides,
            padding: padding)
    }
}

public extension DepthwiseConv2D {
    /// Creates a `DepthwiseConv2D` layer with the specified filter shape, strides, padding, and
    /// element-wise activation function. The filter tensor is initialized using Glorot uniform
    /// initialization with the specified seed. The bias vector is initialized with zeros.
    ///
    /// - Parameters:
    ///   - filterShape: The shape of the 4-D convolution kernel, given as
    ///     (height, width, input channel count, channel multiplier).
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - seed: The random seed for initialization. The default value is random.
    init(
        filterShape: (Int, Int, Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
                                Int32.random(in: Int32.min..<Int32.max))
    ) {
        let filterTensorShape = TensorShape([
            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
        // A depthwise convolution produces one output channel per
        // (input channel, channel multiplier) pair, so the bias needs
        // `filterShape.2 * filterShape.3` entries. Using only `filterShape.3` fails to
        // broadcast against the convolution output whenever both values exceed 1.
        let outputChannelCount = filterShape.2 * filterShape.3
        self.init(
            filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
            bias: Tensor(zeros: [outputChannelCount]),
            activation: activation,
            strides: strides,
            padding: padding)
    }
}
118 changes: 118 additions & 0 deletions Sources/TensorFlow/Operators/NN.swift
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,101 @@ func _vjpConv3D<Scalar: TensorFlowFloatingPoint>(
})
}

/// Gradient helper for the input of `depthwiseConv2D`, backed by the TensorFlow builtin
/// `DepthwiseConv2dNativeBackpropInput` op.
///
/// - Parameters:
///   - x: The tensor forwarded to the raw op as `outBackprop`.
///   - shape: The tensor forwarded to the raw op as `inputSizes`.
///   - filter: The depthwise convolution filter.
///   - strides: The strides for each of the four dimensions.
///   - padding: The padding for the operation.
/// - Returns: The result of the raw backprop-input op.
@differentiable(wrt: (x, filter), vjp: _vjpdepthwiseConv2dBackpropInput)
@usableFromInline
func depthwiseConv2dBackpropInput<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    shape: Tensor<Int32>,
    filter: Tensor<Scalar>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> Tensor<Scalar> {
    // The raw op takes its strides as a flat `Int32` array.
    let rawStrides = [
        Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)]
    let rawPadding = padding.raw
    return Raw.depthwiseConv2dNativeBackpropInput(
        inputSizes: shape,
        filter: filter,
        outBackprop: x,
        strides: rawStrides,
        padding: rawPadding)
}

/// Gradient helper for the filter of `depthwiseConv2D`, backed by the TensorFlow builtin
/// `DepthwiseConv2dNativeBackpropFilter` op.
///
/// - Parameters:
///   - x: The tensor forwarded to the raw op as `outBackprop`.
///   - input: The input of the forward convolution.
///   - filterSizes: The tensor forwarded to the raw op as `filterSizes`.
///   - strides: The strides for each of the four dimensions.
///   - padding: The padding for the operation.
/// - Returns: The result of the raw backprop-filter op.
@differentiable(wrt: (x, input), vjp: _vjpdepthwiseConv2dBackpropFilter)
@usableFromInline
func depthwiseConv2dBackpropFilter<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    input: Tensor<Scalar>,
    filterSizes: Tensor<Int32>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> Tensor<Scalar> {
    // The first operand of the raw op is the forward convolution's input; `x` is only the
    // output gradient. Passing `x` in both positions ignored `input` entirely.
    return Raw.depthwiseConv2dNativeBackpropFilter(
        input,
        filterSizes: filterSizes,
        outBackprop: x,
        strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
        padding: padding.raw)
}

/// Vector-Jacobian product of `depthwiseConv2dBackpropInput` with respect to `(x, filter)`.
///
/// - Returns: The value of `depthwiseConv2dBackpropInput` and a pullback that maps a cotangent
///   `v` of that value to the pair `(gradient for x, gradient for filter)`.
@usableFromInline
func _vjpdepthwiseConv2dBackpropInput<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    _ shape: Tensor<Int32>,
    _ filter: Tensor<Scalar>,
    _ strides: (Int, Int, Int, Int),
    _ padding: Padding
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    let value = depthwiseConv2dBackpropInput(x, shape: shape, filter: filter, strides: strides,
                                             padding: padding)
    return (value, { v in
        // The pullback results must follow the `wrt: (x, filter)` order.
        return (
            // `x` plays the role of the convolution's output gradient, so its own gradient
            // is the forward convolution of `v` with `filter`.
            depthwiseConv2D(v, filter: filter, strides: strides, padding: padding),
            // The filter gradient must be sized like the filter, not like `shape`
            // (which is the size of the convolution input).
            depthwiseConv2dBackpropFilter(x, input: v, filterSizes: filter.shapeTensor,
                                          strides: strides, padding: padding)
        )
    })
}

/// Vector-Jacobian product of `depthwiseConv2dBackpropFilter` with respect to `(x, input)`.
///
/// - Returns: The value of `depthwiseConv2dBackpropFilter` and a pullback that maps a
///   cotangent `v` of that value to the pair `(gradient for x, gradient for input)`.
@usableFromInline
func _vjpdepthwiseConv2dBackpropFilter<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    _ input: Tensor<Scalar>,
    _ filterSizes: Tensor<Int32>,
    _ strides: (Int, Int, Int, Int),
    _ padding: Padding
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    let value = depthwiseConv2dBackpropFilter(x, input: input, filterSizes: filterSizes,
                                              strides: strides, padding: padding)
    return (value, { v in
        // The pullback results must follow the `wrt: (x, input)` order.
        return (
            // `x` plays the role of the convolution's output gradient, so its own gradient
            // is the forward convolution of `input` with the filter-shaped cotangent `v`.
            depthwiseConv2D(input, filter: v, strides: strides, padding: padding),
            // The input gradient must be sized like `input`, not like the filter.
            depthwiseConv2dBackpropInput(x, shape: input.shapeTensor, filter: v,
                                         strides: strides, padding: padding)
        )
    })
}

/// Vector-Jacobian product of `depthwiseConv2D` with respect to `(x, filter)`.
///
/// - Returns: The convolution result and a pullback that maps a cotangent of that result to
///   the pair `(gradient for x, gradient for filter)`.
@usableFromInline
func _vjpDepthwiseConv2D<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    filter: Tensor<Scalar>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    let value = depthwiseConv2D(x, filter: filter, strides: strides, padding: padding)
    return (value, { cotangent in
        // Gradient for the input, sized like `x`.
        let inputGradient = depthwiseConv2dBackpropInput(
            cotangent,
            shape: x.shapeTensor,
            filter: filter,
            strides: strides,
            padding: padding)
        // Gradient for the filter, sized like `filter`.
        let filterGradient = depthwiseConv2dBackpropFilter(
            cotangent,
            input: x,
            filterSizes: filter.shapeTensor,
            strides: strides,
            padding: padding)
        return (inputGradient, filterGradient)
    })
}

@usableFromInline
func _vjpMaxPool2D<Scalar: TensorFlowFloatingPoint>(
_ x: Tensor<Scalar>,
Expand Down Expand Up @@ -432,6 +527,29 @@ public func conv3D<Scalar: TensorFlowFloatingPoint>(
padding: padding.raw)
}

/// Returns the 2-D depthwise convolution of the input with the given filter, using the
/// specified strides and padding.
///
/// - Parameters:
///   - input: The input.
///   - filter: The depthwise convolution filter.
///   - strides: The strides of the sliding filter for each dimension of the input.
///   - padding: The padding for the operation.
/// - Precondition: `input` must have rank 4.
/// - Precondition: `filter` must have rank 4.
@differentiable(wrt: (input, filter), vjp: _vjpDepthwiseConv2D)
public func depthwiseConv2D<Scalar: TensorFlowFloatingPoint>(
    _ input: Tensor<Scalar>,
    filter: Tensor<Scalar>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> Tensor<Scalar> {
    // The raw op takes its strides as a flat `Int32` array.
    let rawStrides = [
        Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)]
    let rawPadding = padding.raw
    return Raw.depthwiseConv2dNative(
        input,
        filter: filter,
        strides: rawStrides,
        padding: rawPadding)
}

/// Computes a 2-D max pooling, with the specified filter sizes, strides, and
/// padding.
///
Expand Down
14 changes: 14 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

/// Checks the forward pass of `DepthwiseConv2D` against hand-computed values.
func testDepthConv2D() {
    let filter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
    let bias = Tensor<Float>([1, 2, 3, 4])
    // The input has height 1 while the filter has height 2, so `.valid` padding leaves no
    // valid window. `.same` padding yields the expected [1, 1, 4, 4] output for stride 2.
    let layer = DepthwiseConv2D<Float>(filter: filter, bias: bias, activation: identity,
                                       strides: (2, 2), padding: .same)
    let input = Tensor(shape: [1, 1, 8, 2], scalars: (0..<16).map(Float.init))
    let output = layer.inferring(from: input)
    let expected = Tensor<Float>(shape: [1, 1, 4, 4],
                                 scalars: [9, 12, 23, 28, 25, 36, 55, 68, 41, 60, 87, 108,
                                           57, 84, 119, 148])
    XCTAssertEqual(output, expected)
}

func testMaxPool1D() {
let layer = MaxPool1D<Float>(poolSize: 3, stride: 1, padding: .valid)
let input = Tensor<Float>([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2)
Expand Down Expand Up @@ -241,6 +254,7 @@ final class LayerTests: XCTestCase {
("testConv1D", testConv1D),
("testConv2D", testConv2D),
("testConv3D", testConv3D),
("testDepthConv2D", testDepthConv2D),
("testMaxPool1D", testMaxPool1D),
("testMaxPool2D", testMaxPool2D),
("testMaxPool3D", testMaxPool3D),
Expand Down