This repository was archived by the owner on Jul 1, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Original file line number Diff line number Diff line change @@ -1009,12 +1009,32 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
1009
1009
}
1010
1010
}
1011
1011
1012
+
1013
+ /// A global average pooling layer for temporal data.
1014
+ @_fixed_layout
1015
+ public struct GlobalAveragePooling1D < Scalar: TensorFlowFloatingPoint > : Layer {
1016
+ /// Creates a global average pooling layer.
1017
+ public init ( ) { }
1018
+
1019
+ /// Returns the output obtained from applying the layer to the given input.
1020
+ ///
1021
+ /// - Parameters:
1022
+ /// - input: The input to the layer.
1023
+ /// - context: The contextual information for the layer application, e.g. the current learning
1024
+ /// phase.
1025
+ /// - Returns: The output
1026
+ @differentiable
1027
+ public func applied( to input: Tensor < Scalar > , in _: Context ) -> Tensor < Scalar > {
1028
+ return input. mean ( alongAxes: 1 ) . reshaped ( to: [ input. shape [ 0 ] , input. shape [ 2 ] ] )
1029
+ }
1030
+ }
1031
+
1012
1032
/// A global average pooling layer for spatial data.
1013
1033
@_fixed_layout
1014
1034
public struct GlobalAveragePooling2D < Scalar: TensorFlowFloatingPoint > : Layer {
1015
1035
/// Creates a global average pooling layer.
1016
1036
public init ( ) { }
1017
-
1037
+
1018
1038
/// Returns the output obtained from applying the layer to the given input.
1019
1039
///
1020
1040
/// - Parameters:
You can’t perform that action at this time.
0 commit comments