@@ -138,7 +138,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
138
138
@noDerivative public let padding : Padding
139
139
/// The dilation factor for spatial dimensions.
140
140
@noDerivative public let dilations : ( Int , Int )
141
-
141
+
142
142
/// The element-wise activation function type.
143
143
public typealias Activation = @differentiable ( Tensor < Scalar > ) -> Tensor < Scalar >
144
144
@@ -470,7 +470,7 @@ public struct DepthwiseConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
470
470
/// The element-wise activation function type.
471
471
public typealias Activation = @differentiable ( Tensor < Scalar > ) -> Tensor < Scalar >
472
472
473
- /// Creates a `DepthwiseConv2D` layer with the specified filter, bias, activation function,
473
+ /// Creates a `DepthwiseConv2D` layer with the specified filter, bias, activation function,
474
474
/// strides, and padding.
475
475
///
476
476
/// - Parameters:
@@ -631,3 +631,107 @@ public struct ZeroPadding3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
631
631
return input. padded ( forSizes: [ padding. 0 , padding. 1 , padding. 2 ] )
632
632
}
633
633
}
634
+
635
/// A 2-D separable convolution layer.
///
/// This layer performs a depthwise convolution that acts separately on channels, followed by
/// a pointwise convolution that mixes channels.
@frozen
public struct SeparableConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The 4-D depthwise convolution kernel.
    public var depthwiseFilter: Tensor<Scalar>
    /// The 4-D pointwise convolution kernel.
    public var pointwiseFilter: Tensor<Scalar>
    /// The bias vector.
    public var bias: Tensor<Scalar>
    /// The element-wise activation function.
    @noDerivative public let activation: Activation
    /// The strides of the sliding window for spatial dimensions.
    @noDerivative public let strides: (Int, Int)
    /// The padding algorithm for convolution.
    @noDerivative public let padding: Padding

    /// The element-wise activation function type.
    public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

    /// Creates a `SeparableConv2D` layer with the specified depthwise and pointwise filters,
    /// bias, activation function, strides, and padding.
    ///
    /// - Parameters:
    ///   - depthwiseFilter: The 4-D depthwise convolution kernel.
    ///   - pointwiseFilter: The 4-D pointwise convolution kernel.
    ///   - bias: The bias vector.
    ///   - activation: The element-wise activation function.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    public init(
        depthwiseFilter: Tensor<Scalar>,
        pointwiseFilter: Tensor<Scalar>,
        bias: Tensor<Scalar>,
        activation: @escaping Activation = identity,
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid
    ) {
        self.depthwiseFilter = depthwiseFilter
        self.pointwiseFilter = pointwiseFilter
        self.bias = bias
        self.activation = activation
        self.strides = strides
        self.padding = padding
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// The input is first passed through a depthwise convolution, the result is passed through
    /// a pointwise convolution, the bias is added, and the activation function is applied.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The output.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        // Depthwise stage: the layer's configured strides fill the two middle positions of
        // the 4-tuple; the outer positions are fixed at 1.
        let depthwise = depthwiseConv2D(
            input,
            filter: depthwiseFilter,
            strides: (1, strides.0, strides.1, 1),
            padding: padding)
        // Pointwise stage: always unit strides. This must call the `conv2D` function, not the
        // `Conv2D` layer type, so that the result is a tensor the bias can be added to.
        return activation(conv2D(
            depthwise,
            filter: pointwiseFilter,
            strides: (1, 1, 1, 1),
            padding: padding) + bias)
    }
}
700
+
701
public extension SeparableConv2D {
    /// Creates a `SeparableConv2D` layer with the specified depthwise and pointwise filter
    /// shapes, strides, padding, and element-wise activation function.
    ///
    /// The filters and bias are produced by the given initializers and then forwarded to the
    /// memberwise `init(depthwiseFilter:pointwiseFilter:bias:activation:strides:padding:)`.
    ///
    /// - Parameters:
    ///   - depthwiseFilterShape: The shape of the 4-D depthwise convolution kernel.
    ///   - pointwiseFilterShape: The shape of the 4-D pointwise convolution kernel. Its last
    ///     component is also used as the length of the bias vector.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding algorithm for convolution.
    ///   - activation: The element-wise activation function.
    ///   - depthwiseFilterInitializer: Initializer to use for the depthwise filter parameters.
    ///   - pointwiseFilterInitializer: Initializer to use for the pointwise filter parameters.
    ///   - biasInitializer: Initializer to use for the bias parameters.
    init(
        depthwiseFilterShape: (Int, Int, Int, Int),
        pointwiseFilterShape: (Int, Int, Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid,
        activation: @escaping Activation = identity,
        depthwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
        pointwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
        biasInitializer: ParameterInitializer<Scalar> = zeros()
    ) {
        // Convert the 4-tuples into `TensorShape` values that the initializers accept.
        let depthwiseFilterTensorShape = TensorShape([
            depthwiseFilterShape.0, depthwiseFilterShape.1, depthwiseFilterShape.2,
            depthwiseFilterShape.3])
        let pointwiseFilterTensorShape = TensorShape([
            pointwiseFilterShape.0, pointwiseFilterShape.1, pointwiseFilterShape.2,
            pointwiseFilterShape.3])
        self.init(
            depthwiseFilter: depthwiseFilterInitializer(depthwiseFilterTensorShape),
            pointwiseFilter: pointwiseFilterInitializer(pointwiseFilterTensorShape),
            // One bias entry per element of the pointwise filter's last dimension.
            bias: biasInitializer([pointwiseFilterShape.3]),
            activation: activation,
            strides: strides,
            padding: padding)
    }
}
0 commit comments