Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2a8d756

Browse files
Shashi456rxwei
authored andcommitted
Adding GlobalAveragePooling1D (#66)
1 parent b64ee81 commit 2a8d756

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,12 +1009,32 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
10091009
}
10101010
}
10111011

1012+
1013+
/// A global average pooling layer for temporal data.
1014+
@_fixed_layout
1015+
public struct GlobalAveragePooling1D<Scalar: TensorFlowFloatingPoint>: Layer {
1016+
/// Creates a global average pooling layer.
1017+
public init() {}
1018+
1019+
/// Returns the output obtained from applying the layer to the given input.
1020+
///
1021+
/// - Parameters:
1022+
/// - input: The input to the layer.
1023+
/// - context: The contextual information for the layer application, e.g. the current learning
1024+
/// phase.
1025+
/// - Returns: The output
1026+
@differentiable
1027+
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
1028+
return input.mean(alongAxes: 1).reshaped(to: [input.shape[0], input.shape[2]])
1029+
}
1030+
}
1031+
10121032
/// A global average pooling layer for spatial data.
10131033
@_fixed_layout
10141034
public struct GlobalAveragePooling2D<Scalar: TensorFlowFloatingPoint>: Layer {
10151035
/// Creates a global average pooling layer.
10161036
public init() {}
1017-
1037+
10181038
/// Returns the output obtained from applying the layer to the given input.
10191039
///
10201040
/// - Parameters:

0 commit comments

Comments
 (0)