Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b64ee81

Browse files
tanmayb123rxwei
authored andcommitted
Add GlobalAveragePooling2D (#65)
1 parent ee6172c commit b64ee81

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,25 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
10091009
}
10101010
}
10111011

1012+
/// A global average pooling layer for spatial data.
1013+
@_fixed_layout
1014+
public struct GlobalAveragePooling2D<Scalar: TensorFlowFloatingPoint>: Layer {
1015+
/// Creates a global average pooling layer.
1016+
public init() {}
1017+
1018+
/// Returns the output obtained from applying the layer to the given input.
1019+
///
1020+
/// - Parameters:
1021+
/// - input: The input to the layer.
1022+
/// - context: The contextual information for the layer application, e.g. the current learning
1023+
/// phase.
1024+
/// - Returns: The output.
1025+
@differentiable
1026+
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
1027+
return input.mean(alongAxes: [1, 2]).reshaped(to: [input.shape[0], input.shape[3]])
1028+
}
1029+
}
1030+
10121031
/// A layer that applies layer normalization over a mini-batch of inputs.
10131032
///
10141033
/// Reference: [Layer Normalization](https://arxiv.org/abs/1607.06450).

0 commit comments

Comments
 (0)