Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 99add87

Browse files
Shashi456 authored and saeta committed
Add Avgpool3d and Maxpool3d Layers & Tests (#117)
Adds Avgpool3d and Maxpool3d layers and tests for both 2d & 3d pooling layers.
1 parent 1dc4afb commit 99add87

File tree

2 files changed

+145
-10
lines changed

2 files changed

+145
-10
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -828,11 +828,9 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
828828
/// - poolSize: Vertical and horizontal factors by which to downscale.
829829
/// - strides: The strides.
830830
/// - padding: The padding.
831-
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
832-
self.poolSize = (1, poolSize.0, poolSize.1, 1)
833-
self.strides = (1, strides.0, strides.1, 1)
834-
self.padding = padding
835-
}
831+
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
    // Delegate to the 4-tuple initializer; the first and last entries of the
    // window and of the strides are fixed at 1.
    self.init(poolSize: (1, poolSize.0, poolSize.1, 1),
              strides: (1, strides.0, strides.1, 1),
              padding: padding)
}
836834

837835
/// Returns the output obtained from applying the layer to the given input.
838836
///
@@ -845,6 +843,58 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
845843
}
846844
}
847845

846+
/// A max pooling layer for spatial or spatio-temporal data.
@_fixed_layout
public struct MaxPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The size of the sliding reduction window for pooling.
    @noDerivative let poolSize: (Int, Int, Int, Int, Int)
    /// The strides of the sliding window for each dimension of a 5-D input.
    /// Strides in non-spatial dimensions must be `1`.
    @noDerivative let strides: (Int, Int, Int, Int, Int)
    /// The padding algorithm for pooling.
    @noDerivative let padding: Padding

    /// Creates a max pooling layer from a full 5-D window and stride specification.
    ///
    /// - Parameters:
    ///   - poolSize: The size of the pooling window for each dimension of a 5-D input.
    ///   - strides: The stride of the window for each dimension of a 5-D input.
    ///   - padding: The padding algorithm for pooling.
    public init(
        poolSize: (Int, Int, Int, Int, Int),
        strides: (Int, Int, Int, Int, Int),
        padding: Padding
    ) {
        self.poolSize = poolSize
        self.strides = strides
        self.padding = padding
    }

    /// Creates a max pooling layer.
    ///
    /// - Parameters:
    ///   - poolSize: Factors by which to downscale along each of the three pooled dimensions.
    ///   - strides: The strides along each of the three pooled dimensions.
    ///   - padding: The padding.
    public init(poolSize: (Int, Int, Int), strides: (Int, Int, Int), padding: Padding = .valid) {
        // The first and last entries of the 5-D window and strides are fixed at 1;
        // only the middle three dimensions are pooled.
        self.init(poolSize: (1, poolSize.0, poolSize.1, poolSize.2, 1),
                  strides: (1, strides.0, strides.1, strides.2, 1),
                  padding: padding)
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The output.
    @differentiable
    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        return input.maxPooled(kernelSize: poolSize, strides: strides, padding: padding)
    }
}
887+
888+
public extension MaxPool3D {
    /// Creates a max pooling layer with the specified pooling window size and stride. All
    /// pooling sizes and strides are the same.
    ///
    /// - Parameters:
    ///   - poolSize: The window size used for every pooled dimension.
    ///   - stride: The stride used for every pooled dimension.
    ///   - padding: The padding.
    init(poolSize: Int, stride: Int, padding: Padding = .valid) {
        // The label must be spelled `poolSize:`; Swift argument labels are
        // case-sensitive and `poolsize:` matches no initializer.
        self.init(poolSize: (poolSize, poolSize, poolSize),
                  strides: (stride, stride, stride),
                  padding: padding)
    }
}
897+
848898
/// An average pooling layer for temporal data.
849899
@_fixed_layout
850900
public struct AvgPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
@@ -894,7 +944,7 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
894944
/// The padding algorithm for pooling.
895945
@noDerivative let padding: Padding
896946

897-
/// Creates a average pooling layer.
947+
/// Creates an average pooling layer.
898948
public init(
899949
poolSize: (Int, Int, Int, Int),
900950
strides: (Int, Int, Int, Int),
@@ -905,18 +955,58 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
905955
self.padding = padding
906956
}
907957

908-
/// Creates a average pooling layer.
958+
/// Creates an average pooling layer.
909959
///
910960
/// - Parameters:
911961
/// - poolSize: Vertical and horizontal factors by which to downscale.
912962
/// - strides: The strides.
913963
/// - padding: The padding.
914-
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
915-
self.poolSize = (1, poolSize.0, poolSize.1, 1)
916-
self.strides = (1, strides.0, strides.1, 1)
964+
public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
    // Delegate to the 4-tuple initializer; the first and last entries of the
    // window and of the strides are fixed at 1.
    self.init(poolSize: (1, poolSize.0, poolSize.1, 1),
              strides: (1, strides.0, strides.1, 1),
              padding: padding)
}
967+
968+
/// Returns the output obtained from applying the layer to the given input.
969+
///
970+
/// - Parameter input: The input to the layer.
971+
/// - Returns: The output.
972+
@differentiable
973+
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
974+
return input.averagePooled(kernelSize: poolSize, strides: strides, padding: padding)
975+
}
976+
}
977+
978+
/// An average pooling layer for spatial or spatio-temporal data.
979+
@_fixed_layout
980+
public struct AvgPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
981+
/// The size of the sliding reduction window for pooling.
982+
@noDerivative let poolSize: (Int, Int, Int, Int, Int)
983+
/// The strides of the sliding window for each dimension of a 5-D input.
984+
/// Strides in non-spatial dimensions must be `1`.
985+
@noDerivative let strides: (Int, Int, Int, Int, Int)
986+
/// The padding algorithm for pooling.
987+
@noDerivative let padding: Padding
988+
989+
/// Creates an average pooling layer.
990+
public init(
991+
poolSize: (Int, Int, Int, Int, Int),
992+
strides: (Int, Int, Int, Int, Int),
993+
padding: Padding
994+
) {
995+
self.poolSize = poolSize
996+
self.strides = strides
917997
self.padding = padding
918998
}
919999

1000+
/// Creates an average pooling layer.
///
/// - Parameters:
///   - poolSize: Factors by which to downscale along each of the three pooled dimensions.
///   - strides: The strides along each of the three pooled dimensions.
///   - padding: The padding.
public init(poolSize: (Int, Int, Int), strides: (Int, Int, Int), padding: Padding = .valid) {
    // The first and last entries of the 5-D window and strides are fixed at 1;
    // only the middle three dimensions are pooled.
    self.init(poolSize: (1, poolSize.0, poolSize.1, poolSize.2, 1),
              strides: (1, strides.0, strides.1, strides.2, 1),
              padding: padding)
}
1009+
9201010
/// Returns the output obtained from applying the layer to the given input.
9211011
///
9221012
/// - Parameter input: The input to the layer.
@@ -927,6 +1017,15 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
9271017
}
9281018
}
9291019

1020+
public extension AvgPool3D {
    /// Creates an average pooling layer whose window size and stride are uniform
    /// across all three pooled dimensions.
    ///
    /// - Parameters:
    ///   - poolSize: The window size used for every pooled dimension.
    ///   - strides: The stride used for every pooled dimension.
    ///   - padding: The padding.
    init(poolSize: Int, strides: Int, padding: Padding = .valid) {
        let window = (poolSize, poolSize, poolSize)
        let steps = (strides, strides, strides)
        self.init(poolSize: window, strides: steps, padding: padding)
    }
}
9301029

9311030
/// A global average pooling layer for temporal data.
9321031
@_fixed_layout

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ final class LayerTests: XCTestCase {
3434
XCTAssertEqual(round(output), expected)
3535
}
3636

37+
func testMaxPool2D() {
    // Input holds the values 0..<4; the test expects one output element equal
    // to the largest of them.
    let scalars: [Float] = (0..<4).map(Float.init)
    let input = Tensor(shape: [1, 2, 2, 1], scalars: scalars)
    let expected = Tensor<Float>([[[[3]]]])
    let layer = MaxPool2D<Float>(poolSize: (2, 2), strides: (1, 1), padding: .valid)
    let output = layer.inferring(from: input)
    XCTAssertEqual(round(output), expected)
}
44+
45+
func testMaxPool3D() {
    // Input holds the values 0..<8; the test expects one output element equal
    // to the largest of them.
    let scalars: [Float] = (0..<8).map(Float.init)
    let input = Tensor(shape: [1, 2, 2, 2, 1], scalars: scalars)
    let expected = Tensor<Float>([[[[[7]]]]])
    let layer = MaxPool3D<Float>(poolSize: (2, 2, 2), strides: (1, 1, 1), padding: .valid)
    let output = layer.inferring(from: input)
    XCTAssertEqual(round(output), expected)
}
52+
3753
func testAvgPool1D() {
3854
let layer = AvgPool1D<Float>(poolSize: 3, stride: 1, padding: .valid)
3955
let input = Tensor<Float>([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2)
@@ -42,6 +58,22 @@ final class LayerTests: XCTestCase {
4258
XCTAssertEqual(round(output), expected)
4359
}
4460

61+
func testAvgPool2D() {
    // Input holds the values 0..<10; the test expects one output element equal
    // to their mean, 4.5.
    let scalars: [Float] = (0..<10).map(Float.init)
    let input = Tensor(shape: [1, 2, 5, 1], scalars: scalars)
    let expected = Tensor<Float>([[[[4.5]]]])
    let layer = AvgPool2D<Float>(poolSize: (2, 5), strides: (1, 1), padding: .valid)
    let output = layer.inferring(from: input)
    XCTAssertEqual(output, expected)
}
68+
69+
func testAvgPool3D() {
    // AvgPool3D's initializers are labelled `strides:`, not `stride:`.
    let layer = AvgPool3D<Float>(poolSize: (2, 4, 5), strides: (1, 1, 1), padding: .valid)
    // Shape [1, 2, 4, 5, 1] holds 1 * 2 * 4 * 5 * 1 = 40 scalars, so the range
    // must produce 40 values.
    let input = Tensor(shape: [1, 2, 4, 5, 1], scalars: (0..<40).map(Float.init))
    let output = layer.inferring(from: input)
    // Mean of 0...39 is 780 / 40 = 19.5, the single output element expected here.
    let expected = Tensor<Float>([[[[[19.5]]]]])
    XCTAssertEqual(output, expected)
}
76+
4577
func testGlobalAvgPool1D() {
4678
let layer = GlobalAvgPool1D<Float>()
4779
let input = Tensor(shape: [2, 5, 1], scalars: (0..<10).map(Float.init))
@@ -150,7 +182,11 @@ final class LayerTests: XCTestCase {
150182
static var allTests = [
151183
("testConv1D", testConv1D),
152184
("testMaxPool1D", testMaxPool1D),
185+
("testMaxPool2D", testMaxPool2D),
186+
("testMaxPool3D", testMaxPool3D),
153187
("testAvgPool1D", testAvgPool1D),
188+
("testAvgPool2D", testAvgPool2D),
189+
("testAvgPool3D", testAvgPool3D),
154190
("testGlobalAvgPool1D", testGlobalAvgPool1D),
155191
("testGlobalAvgPool2D", testGlobalAvgPool2D),
156192
("testGlobalAvgPool3D", testGlobalAvgPool3D),

0 commit comments

Comments
 (0)