Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 036c471

Browse files
Shashi456 authored and rxwei committed
Add DepthwiseConv2D (#241)
1 parent cf50393 commit 036c471

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,119 @@ public extension TransposedConv2D {
485485
padding: padding)
486486
}
487487
}
488+
489+
/// A 2-D depthwise convolution layer.
///
/// Applying the layer runs `depthwiseConv2D` over the input with `filter`, adds `bias` to the
/// result, and passes the sum through `activation`.
@frozen
public struct DepthwiseConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The 4-D convolution kernel.
    public var filter: Tensor<Scalar>
    /// The bias vector added to the convolution result.
    public var bias: Tensor<Scalar>
    /// The type of a differentiable, tensor-to-tensor activation function.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
    /// The element-wise activation function applied to the biased convolution result.
    @noDerivative public let activation: Activation
    /// The strides of the sliding window for the two spatial dimensions.
    @noDerivative public let strides: (Int, Int)
    /// The padding algorithm used by the convolution.
    @noDerivative public let padding: Padding

    /// Creates a `DepthwiseConv2D` layer from explicitly provided parameters.
    ///
    /// - Parameters:
    ///   - filter: The 4-D convolution kernel.
    ///   - bias: The bias vector.
    ///   - activation: The element-wise activation function.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    public init(
        filter: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation,
        strides: (Int, Int),
        padding: Padding
    ) {
        // Non-differentiable configuration first, then the trainable parameters.
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.filter = filter
        self.bias = bias
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: `activation(depthwiseConv2D(input, filter) + bias)`.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        // The first and last stride components are pinned to 1; only the two
        // user-supplied strides vary.
        let convolved = depthwiseConv2D(
            input,
            filter: filter,
            strides: (1, strides.0, strides.1, 1),
            padding: padding)
        return activation(convolved + bias)
    }
}
542+
543+
public extension DepthwiseConv2D {
    /// Creates a `DepthwiseConv2D` layer with the specified filter shape, strides, padding, and
    /// element-wise activation function. The filter tensor is initialized using Glorot uniform
    /// initialization with the specified generator. The bias vector is initialized with zeros.
    ///
    /// - Parameters:
    ///   - filterShape: The shape of the 4-D convolution kernel.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - generator: The random number generator for initialization.
    ///
    /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random
    ///   initialization.
    init<G: RandomNumberGenerator>(
        filterShape: (Int, Int, Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        generator: inout G
    ) {
        let filterTensorShape = TensorShape([
            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
        // A depthwise convolution produces one output channel per
        // (input channel, channel multiplier) pair, so the bias needs
        // `filterShape.2 * filterShape.3` entries rather than `filterShape.3`.
        let outputChannelCount = filterShape.2 * filterShape.3
        self.init(
            filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
            bias: Tensor(zeros: [outputChannelCount]),
            activation: activation,
            strides: strides,
            padding: padding)
    }
}
574+
575+
public extension DepthwiseConv2D {
    /// Creates a `DepthwiseConv2D` layer with the specified filter shape, strides, padding, and
    /// element-wise activation function. The filter tensor is initialized using Glorot uniform
    /// initialization with the specified seed. The bias vector is initialized with zeros.
    ///
    /// - Parameters:
    ///   - filterShape: The shape of the 4-D convolution kernel.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - seed: The random seed for initialization. The default value is random.
    init(
        filterShape: (Int, Int, Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
                                Int32.random(in: Int32.min..<Int32.max))
    ) {
        let filterTensorShape = TensorShape([
            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
        // A depthwise convolution produces one output channel per
        // (input channel, channel multiplier) pair, so the bias needs
        // `filterShape.2 * filterShape.3` entries rather than `filterShape.3`.
        let outputChannelCount = filterShape.2 * filterShape.3
        self.init(
            filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
            bias: Tensor(zeros: [outputChannelCount]),
            activation: activation,
            strides: strides,
            padding: padding)
    }
}

Sources/TensorFlow/Operators/NN.swift

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,101 @@ func _vjpConv3D<Scalar: TensorFlowFloatingPoint>(
291291
})
292292
}
293293

294+
/// TensorFlow builtin depthwiseConv2D gradient helper for the input.
///
/// Forwards to `Raw.depthwiseConv2dNativeBackpropInput`, passing `x` as the output
/// backpropagation tensor and `shape` as the input sizes.
///
/// - Parameters:
///   - x: The tensor passed to the raw op as `outBackprop`.
///   - shape: The tensor passed to the raw op as `inputSizes`.
///   - filter: The depthwise convolution filter.
///   - strides: The four stride components, converted to `Int32` for the raw op.
///   - padding: The padding for the operation.
@differentiable(wrt: (x, filter), vjp: _vjpdepthwiseConv2dBackpropInput)
@usableFromInline
func depthwiseConv2dBackpropInput<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    shape: Tensor<Int32>,
    filter: Tensor<Scalar>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> Tensor<Scalar> {
    let rawStrides = [strides.0, strides.1, strides.2, strides.3].map { Int32($0) }
    return Raw.depthwiseConv2dNativeBackpropInput(
        inputSizes: shape,
        filter: filter,
        outBackprop: x,
        strides: rawStrides,
        padding: padding.raw)
}
311+
312+
/// TensorFlow builtin depthwiseConv2D gradient helper for the filter.
///
/// Forwards to `Raw.depthwiseConv2dNativeBackpropFilter`.
///
/// - Parameters:
///   - x: The tensor passed to the raw op as `outBackprop`.
///   - input: The tensor passed to the raw op as its input (first) argument.
///   - filterSizes: The tensor passed to the raw op as `filterSizes`.
///   - strides: The four stride components, converted to `Int32` for the raw op.
///   - padding: The padding for the operation.
@differentiable(wrt: (x, input), vjp: _vjpdepthwiseConv2dBackpropFilter)
@usableFromInline
func depthwiseConv2dBackpropFilter<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    input: Tensor<Scalar>,
    filterSizes: Tensor<Int32>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> Tensor<Scalar> {
    // The raw op takes the convolution input first and the output gradient as
    // `outBackprop`. Passing `x` for both (as before) ignored `input` entirely.
    return Raw.depthwiseConv2dNativeBackpropFilter(
        input,
        filterSizes: filterSizes,
        outBackprop: x,
        strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
        padding: padding.raw)
}
329+
330+
/// Custom VJP for `depthwiseConv2dBackpropInput`, differentiating with respect to
/// `(x, filter)`.
///
/// - Returns: The value of `depthwiseConv2dBackpropInput` and a pullback that maps a
///   cotangent `v` to the pair `(gradient for x, gradient for filter)`.
@usableFromInline
func _vjpdepthwiseConv2dBackpropInput<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    _ shape: Tensor<Int32>,
    _ filter: Tensor<Scalar>,
    _ strides: (Int, Int, Int, Int),
    _ padding: Padding
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    let value = depthwiseConv2dBackpropInput(x, shape: shape, filter: filter, strides: strides,
                                             padding: padding)
    return (value, { v in
        // The pullback results must follow the `wrt: (x, filter)` order. The
        // gradient for `x` is the forward convolution of the cotangent. The
        // gradient for `filter` must be sized like the filter, so it uses
        // `filter.shapeTensor` rather than `shape` (the input sizes).
        return (
            depthwiseConv2D(v, filter: filter, strides: strides, padding: padding),
            depthwiseConv2dBackpropFilter(x, input: v, filterSizes: filter.shapeTensor,
                                          strides: strides, padding: padding)
        )
    })
}
348+
349+
/// Custom VJP for `depthwiseConv2dBackpropFilter`, differentiating with respect to
/// `(x, input)`.
///
/// - Returns: The value of `depthwiseConv2dBackpropFilter` and a pullback that maps a
///   cotangent `v` to the pair `(gradient for x, gradient for input)`.
@usableFromInline
func _vjpdepthwiseConv2dBackpropFilter<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    _ input: Tensor<Scalar>,
    _ filterSizes: Tensor<Int32>,
    _ strides: (Int, Int, Int, Int),
    _ padding: Padding
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    let value = depthwiseConv2dBackpropFilter(x, input: input, filterSizes: filterSizes,
                                              strides: strides, padding: padding)
    return (value, { v in
        // The pullback results must follow the `wrt: (x, input)` order. The
        // gradient for `x` is the forward convolution of `input` with the
        // cotangent used as the filter. The gradient for `input` must be sized
        // like `input`, so it uses `input.shapeTensor` rather than `filterSizes`.
        return (
            depthwiseConv2D(input, filter: v, strides: strides, padding: padding),
            depthwiseConv2dBackpropInput(x, shape: input.shapeTensor, filter: v,
                                         strides: strides, padding: padding)
        )
    })
}
367+
368+
/// Custom VJP for `depthwiseConv2D`, differentiating with respect to `(x, filter)`.
///
/// - Returns: The convolution result and a pullback that maps a cotangent to the pair
///   produced by `depthwiseConv2dBackpropInput` and `depthwiseConv2dBackpropFilter`.
@usableFromInline
func _vjpDepthwiseConv2D<Scalar: TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>,
    filter: Tensor<Scalar>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    let output = depthwiseConv2D(x, filter: filter, strides: strides, padding: padding)
    return (output, { upstream in
        // Gradient with respect to the input, sized like `x`.
        let inputGradient = depthwiseConv2dBackpropInput(
            upstream,
            shape: x.shapeTensor,
            filter: filter,
            strides: strides,
            padding: padding)
        // Gradient with respect to the filter, sized like `filter`.
        let filterGradient = depthwiseConv2dBackpropFilter(
            upstream,
            input: x,
            filterSizes: filter.shapeTensor,
            strides: strides,
            padding: padding)
        return (inputGradient, filterGradient)
    })
}
388+
294389
@usableFromInline
295390
func _vjpMaxPool2D<Scalar: TensorFlowFloatingPoint>(
296391
_ x: Tensor<Scalar>,
@@ -432,6 +527,29 @@ public func conv3D<Scalar: TensorFlowFloatingPoint>(
432527
padding: padding.raw)
433528
}
434529

530+
/// Computes a 2-D depthwise convolution with the specified input, filter, strides, and padding.
///
/// Forwards to `Raw.depthwiseConv2dNative` after converting the strides to `Int32`.
///
/// - Parameters:
///   - input: The input.
///   - filter: The depthwise convolution filter.
///   - strides: The strides of the sliding filter for each dimension of the input.
///   - padding: The padding for the operation.
/// - Precondition: `input` must have rank 4.
/// - Precondition: `filter` must have rank 4.
@differentiable(wrt: (input, filter), vjp: _vjpDepthwiseConv2D)
public func depthwiseConv2D<Scalar: TensorFlowFloatingPoint>(
    _ input: Tensor<Scalar>,
    filter: Tensor<Scalar>,
    strides: (Int, Int, Int, Int),
    padding: Padding
) -> Tensor<Scalar> {
    let rawStrides = [strides.0, strides.1, strides.2, strides.3].map { Int32($0) }
    return Raw.depthwiseConv2dNative(
        input,
        filter: filter,
        strides: rawStrides,
        padding: padding.raw)
}
552+
435553
/// Computes a 2-D max pooling, with the specified filter sizes, strides, and
436554
/// padding.
437555
///

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ final class LayerTests: XCTestCase {
5252
XCTAssertEqual(output, expected)
5353
}
5454

55+
/// Checks `DepthwiseConv2D` inference against hand-computed values.
func testDepthConv2D() {
    // The input has a single row, so the kernel must also have height 1. A taller
    // kernel under `.valid` padding fits no window over the input and could not
    // produce the expected values below.
    let filter = Tensor(shape: [1, 2, 2, 2], scalars: (0..<8).map(Float.init))
    // One bias entry per output channel (2 input channels x multiplier 2).
    let bias = Tensor<Float>([1, 2, 3, 4])
    let layer = DepthwiseConv2D<Float>(filter: filter, bias: bias, activation: identity,
                                       strides: (2, 2), padding: .valid)
    let input = Tensor(shape: [1, 1, 8, 2], scalars: (0..<16).map(Float.init))
    let output = layer.inferring(from: input)
    let expected = Tensor<Float>(shape: [1, 1, 4, 4],
                                 scalars: [9, 12, 23, 28, 25, 36, 55, 68, 41, 60, 87, 108,
                                           57, 84, 119, 148])
    XCTAssertEqual(output, expected)
}
67+
5568
func testMaxPool1D() {
5669
let layer = MaxPool1D<Float>(poolSize: 3, stride: 1, padding: .valid)
5770
let input = Tensor<Float>([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2)
@@ -241,6 +254,7 @@ final class LayerTests: XCTestCase {
241254
("testConv1D", testConv1D),
242255
("testConv2D", testConv2D),
243256
("testConv3D", testConv3D),
257+
("testDepthConv2D", testDepthConv2D),
244258
("testMaxPool1D", testMaxPool1D),
245259
("testMaxPool2D", testMaxPool2D),
246260
("testMaxPool3D", testMaxPool3D),

0 commit comments

Comments
 (0)