Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 9b3b7e5

Browse files
Shashi456rxwei
authored andcommitted
Add GlobalMaxPool1D, GlobalMaxPool2D, GlobalMaxPool3D (#76)
1 parent c1b18f7 commit 9b3b7e5

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

Sources/TensorFlow/Layers/Pooling.swift

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,54 @@ public struct GlobalAvgPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
340340
return input.mean(squeezingAxes: [1, 2, 3])
341341
}
342342
}
343+
344+
/// A global max pooling layer for temporal data.
345+
@_fixed_layout
346+
public struct GlobalMaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
347+
/// Creates a global max pooling layer.
348+
public init() {}
349+
350+
/// Returns the output obtained from applying the layer to the given input.
351+
///
352+
/// - Parameters:
353+
/// - input: The input to the layer.
354+
/// - context: The contextual information for the layer application, e.g. the current learning
355+
/// phase.
356+
/// - Returns: The output.
357+
@differentiable
358+
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
359+
return input.max(squeezingAxes: 1)
360+
}
361+
}
362+
363+
/// A global max pooling layer for spatial data.
364+
@_fixed_layout
365+
public struct GlobalMaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
366+
/// Creates a global max pooling layer.
367+
public init() {}
368+
369+
/// Returns the output obtained from applying the layer to the given input.
370+
///
371+
/// - Parameter input: The input to the layer.
372+
/// - Returns: The output.
373+
@differentiable
374+
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
375+
return input.max(squeezingAxes: [1, 2])
376+
}
377+
}
378+
379+
/// A global max pooling layer for spatial and spatio-temporal data.
380+
@_fixed_layout
381+
public struct GlobalMaxPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
382+
/// Creates a global max pooling layer.
383+
public init() {}
384+
385+
/// Returns the output obtained from applying the layer to the given input.
386+
///
387+
/// - Parameter input: The input to the layer.
388+
/// - Returns: The output.
389+
@differentiable
390+
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
391+
return input.max(squeezingAxes: [1, 2, 3])
392+
}
393+
}

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,30 @@ final class LayerTests: XCTestCase {
124124
XCTAssertEqual(output, expected)
125125
}
126126

127+
func testGlobalMaxPool1D() {
128+
let layer = GlobalMaxPool1D<Float>()
129+
let input = Tensor(shape: [1, 10, 1], scalars: (0..<10).map(Float.init))
130+
let output = layer.inferring(from: input)
131+
let expected = Tensor<Float>([9])
132+
XCTAssertEqual(output, expected)
133+
}
134+
135+
func testGlobalMaxPool2D() {
136+
let layer = GlobalMaxPool2D<Float>()
137+
let input = Tensor(shape: [1, 2, 10, 1], scalars: (0..<20).map(Float.init))
138+
let output = layer.inferring(from: input)
139+
let expected = Tensor<Float>([19])
140+
XCTAssertEqual(output, expected)
141+
}
142+
143+
func testGlobalMaxPool3D() {
144+
let layer = GlobalMaxPool3D<Float>()
145+
let input = Tensor<Float>(shape: [1, 2, 3, 5, 1], scalars: (0..<30).map(Float.init))
146+
let output = layer.inferring(from: input)
147+
let expected = Tensor<Float>([29])
148+
XCTAssertEqual(output, expected)
149+
}
150+
127151
func testUpSampling1D() {
128152
let size = 6
129153
let layer = UpSampling1D<Float>(size: size)
@@ -226,6 +250,9 @@ final class LayerTests: XCTestCase {
226250
("testGlobalAvgPool1D", testGlobalAvgPool1D),
227251
("testGlobalAvgPool2D", testGlobalAvgPool2D),
228252
("testGlobalAvgPool3D", testGlobalAvgPool3D),
253+
("testGlobalMaxPool1D", testGlobalMaxPool1D),
254+
("testGlobalMaxPool2D", testGlobalMaxPool2D),
255+
("testGlobalMaxPool3D", testGlobalMaxPool3D),
229256
("testUpSampling1D", testUpSampling1D),
230257
("testUpSampling2D", testUpSampling2D),
231258
("testUpSampling3D", testUpSampling3D),

0 commit comments

Comments
 (0)