Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1dc4afb

Browse files
Shashi456 authored and rxwei committed
Adding Upsampling 3D layer and tests for upsampling layers (#112)
1 parent 17c58d7 commit 1dc4afb

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,35 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
11521152
}
11531153
}
11541154

/// An upsampling layer for 3-D inputs.
@_fixed_layout
public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// The upsampling factor applied to each of the three spatial dimensions.
    @noDerivative public let size: Int

    /// Creates an upsampling layer.
    ///
    /// - Parameter size: The upsampling factor for rows, columns, and depth.
    public init(size: Int) {
        self.size = size
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, expected to be rank 5 with
    ///   shape [batch, height, width, depth, channels].
    /// - Returns: The output, with shape
    ///   [batch, height * size, width * size, depth * size, channels].
    @differentiable
    public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        let shape = input.shape
        let (batchSize, height, width, depth, channels) =
            (shape[0], shape[1], shape[2], shape[3], shape[4])
        // Insert a unit axis after each spatial axis and broadcast-multiply by a
        // ones tensor of extent `size` on those axes; this replicates every
        // element `size` times along height, width, and depth.
        let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1, size, 1, size, 1])
        let upSampling = input.reshaped(
            to: [batchSize, height, 1, width, 1, depth, 1, channels]) * scaleOnes
        // Collapse each (axis, size) pair back into a single scaled axis.
        return upSampling.reshaped(
            to: [batchSize, height * size, width * size, depth * size, channels])
    }
}
11551184
/// A flatten layer.
11561185
///
11571186
/// A flatten layer flattens the input when applied without affecting the batch size.
@@ -1384,7 +1413,7 @@ public struct RNN<Cell: RNNCell>: Layer {
13841413
public typealias Output = [Cell.TimeStepOutput]
13851414

13861415
public var cell: Cell
1387-
1416+
13881417
public init(_ cell: @autoclosure () -> Cell) {
13891418
self.cell = cell()
13901419
}

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,33 @@ final class LayerTests: XCTestCase {
6666
XCTAssertEqual(output, expected)
6767
}
6868

69+
func testUpSampling1D() {
    // Upsampling a [1, 10, 1] input by a factor of 6 should scale the
    // temporal axis by that factor while leaving batch and channels alone.
    let factor = 6
    let layer = UpSampling1D<Float>(size: factor)
    let input = Tensor<Float>(shape: [1, 10, 1], scalars: (0..<10).map(Float.init))
    let output = layer.inferring(from: input)
    let expectedShape = TensorShape([1, input.shape[1] * factor, 1])
    XCTAssertEqual(output.shape, expectedShape)
}
78+
func testUpSampling2D() {
    // Upsampling a [1, 3, 5, 1] input by a factor of 6 should scale both
    // spatial axes by that factor while leaving batch and channels alone.
    let factor = 6
    let layer = UpSampling2D<Float>(size: factor)
    let input = Tensor<Float>(shape: [1, 3, 5, 1], scalars: (0..<15).map(Float.init))
    let output = layer.inferring(from: input)
    let expectedShape = TensorShape([1, input.shape[1] * factor, input.shape[2] * factor, 1])
    XCTAssertEqual(output.shape, expectedShape)
}
86+
87+
func testUpSampling3D() {
    // Upsampling a [1, 4, 3, 2, 1] input by a factor of 6 should scale all
    // three spatial axes by that factor while leaving batch and channels alone.
    let factor = 6
    let layer = UpSampling3D<Float>(size: factor)
    let input = Tensor<Float>(shape: [1, 4, 3, 2, 1], scalars: (0..<24).map(Float.init))
    let output = layer.inferring(from: input)
    let expectedShape = TensorShape(
        [1, input.shape[1] * factor, input.shape[2] * factor, input.shape[3] * factor, 1])
    XCTAssertEqual(output.shape, expectedShape)
}
95+
6996
func testReshape() {
7097
let layer = Reshape<Float>(shape: [10, 2, 1])
7198
let input = Tensor(shape: [20, 1], scalars: (0..<20).map(Float.init))
@@ -127,6 +154,9 @@ final class LayerTests: XCTestCase {
127154
("testGlobalAvgPool1D", testGlobalAvgPool1D),
128155
("testGlobalAvgPool2D", testGlobalAvgPool2D),
129156
("testGlobalAvgPool3D", testGlobalAvgPool3D),
157+
("testUpSampling1D", testUpSampling1D),
158+
("testUpSampling2D", testUpSampling2D),
159+
("testUpSampling3D", testUpSampling3D),
130160
("testReshape", testReshape),
131161
("testFlatten", testFlatten),
132162
("testSimpleRNNCell", testSimpleRNNCell),

0 commit comments

Comments (0)