Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9fab414

Browse files
committed
Add Separable Conv2D
1 parent a3b3aa0 commit 9fab414

File tree

1 file changed

+106
-2
lines changed

1 file changed

+106
-2
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
138138
@noDerivative public let padding: Padding
139139
/// The dilation factor for spatial dimensions.
140140
@noDerivative public let dilations: (Int, Int)
141-
141+
142142
/// The element-wise activation function type.
143143
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
144144

@@ -470,7 +470,7 @@ public struct DepthwiseConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
470470
/// The element-wise activation function type.
471471
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
472472

473-
/// Creates a `DepthwiseConv2D` layer with the specified filter, bias, activation function,
473+
/// Creates a `DepthwiseConv2D` layer with the specified filter, bias, activation function,
474474
/// strides, and padding.
475475
///
476476
/// - Parameters:
@@ -631,3 +631,107 @@ public struct ZeroPadding3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
631631
return input.padded(forSizes: [padding.0, padding.1, padding.2])
632632
}
633633
}
634+
635+
/// A 2-D separable convolution layer.
///
/// This layer performs a depthwise convolution that acts separately on channels, followed by
/// a pointwise convolution that mixes channels.
@frozen
public struct SeparableConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The 4-D depthwise convolution kernel.
    public var depthwiseFilter: Tensor<Scalar>
    /// The 4-D pointwise convolution kernel.
    public var pointwiseFilter: Tensor<Scalar>
    /// The bias vector, added to the output of the pointwise convolution.
    public var bias: Tensor<Scalar>
    /// The element-wise activation function.
    @noDerivative public let activation: Activation
    /// The strides of the sliding window for spatial dimensions.
    @noDerivative public let strides: (Int, Int)
    /// The padding algorithm for convolution.
    @noDerivative public let padding: Padding

    /// The element-wise activation function type.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

    /// Creates a `SeparableConv2D` layer with the specified filters, bias, activation function,
    /// strides, and padding.
    ///
    /// - Parameters:
    ///   - depthwiseFilter: The 4-D depthwise convolution kernel.
    ///   - pointwiseFilter: The 4-D pointwise convolution kernel.
    ///   - bias: The bias vector.
    ///   - activation: The element-wise activation function.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    public init(
        depthwiseFilter: Tensor<Scalar>,
        pointwiseFilter: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation = identity,
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid
    ) {
        self.depthwiseFilter = depthwiseFilter
        self.pointwiseFilter = pointwiseFilter
        self.bias = bias
        self.activation = activation
        self.strides = strides
        self.padding = padding
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// The input is first convolved with `depthwiseFilter` using the layer's strides and
    /// padding. The result is then convolved with `pointwiseFilter` using unit strides,
    /// `bias` is added, and `activation` is applied.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The output.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        // Only the depthwise stage uses the layer's strides.
        let depthwise = depthwiseConv2D(
            input,
            filter: depthwiseFilter,
            strides: (1, strides.0, strides.1, 1),
            padding: padding)
        // Use the free function `conv2D` here: `Conv2D` is the layer type and has no
        // initializer taking these arguments. The pointwise stage is never strided or dilated.
        return activation(conv2D(
            depthwise,
            filter: pointwiseFilter,
            strides: (1, 1, 1, 1),
            padding: padding,
            dilations: (1, 1, 1, 1)) + bias)
    }
}
700+
701+
public extension SeparableConv2D {
    /// Creates a `SeparableConv2D` layer whose parameters are freshly initialized from the
    /// given depthwise and pointwise filter shapes.
    ///
    /// The bias is created with a single dimension equal to the last element of
    /// `pointwiseFilterShape`.
    ///
    /// - Parameters:
    ///   - depthwiseFilterShape: The shape of the 4-D depthwise convolution kernel.
    ///   - pointwiseFilterShape: The shape of the 4-D pointwise convolution kernel.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - depthwiseFilterInitializer: Initializer to use for the depthwise filter parameters.
    ///   - pointwiseFilterInitializer: Initializer to use for the pointwise filter parameters.
    ///   - biasInitializer: Initializer to use for the bias parameters.
    init(
        depthwiseFilterShape: (Int, Int, Int, Int),
        pointwiseFilterShape: (Int, Int, Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        depthwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
        pointwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
        biasInitializer: ParameterInitializer<Scalar> = zeros()
    ) {
        // Unpack the shape tuples once so each dimension can be referred to by position.
        let (dw0, dw1, dw2, dw3) = depthwiseFilterShape
        let (pw0, pw1, pw2, pw3) = pointwiseFilterShape

        // Parameters are created in the order depthwise, pointwise, bias.
        let initialDepthwiseFilter = depthwiseFilterInitializer(TensorShape([dw0, dw1, dw2, dw3]))
        let initialPointwiseFilter = pointwiseFilterInitializer(TensorShape([pw0, pw1, pw2, pw3]))
        let initialBias = biasInitializer([pw3])

        self.init(
            depthwiseFilter: initialDepthwiseFilter,
            pointwiseFilter: initialPointwiseFilter,
            bias: initialBias,
            activation: activation,
            strides: strides,
            padding: padding)
    }
}

0 commit comments

Comments
 (0)