Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 7b2d77f

Browse files
authored
Revert "Revert "Add Recurrent Layers (#71)" (#94)"
This reverts commit f75c5e0.
1 parent db48292 commit 7b2d77f

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,3 +1293,123 @@ public extension RNNCell {
12931293
return applied(to: RNNCellInput(input: input, state: state))
12941294
}
12951295
}
1296+
1297+
/// A Simple RNN Cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    /// The combined input-and-recurrent weight, of shape
    /// `[inputSize + hiddenSize, hiddenSize]`.
    public var weight: Tensor<Scalar>
    /// The bias, of shape `[hiddenSize]`.
    public var bias: Tensor<Scalar>

    /// The shape of the hidden state: a single batch row of `hiddenSize` features.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// An all-zeros initial state.
    @differentiable
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // The input and previous state are concatenated before the matmul, so the
        // weight's first dimension covers both.
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, carrying the time-step input and
    ///   the previous hidden state.
    /// - Returns: The hidden state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along axis 1 (the feature axis): axis 0 is the batch axis,
        // and [batch, inputSize] ++ [batch, hiddenSize] must yield
        // [batch, inputSize + hiddenSize] to match `weight`. This mirrors
        // `LSTMCell.applied(to:)`.
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        let newState = matmul(concatenatedInput, weight) + bias
        return Output(output: newState, state: newState)
    }
}
1342+
1343+
/// An LSTM Cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    // Per-gate weights, each of shape [inputSize + hiddenSize, hiddenSize].
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    // Per-gate biases, each of shape [hiddenSize].
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    /// The shape of each state component: a single batch row of `hiddenSize` features.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// An initial state with all-zeros cell and hidden components.
    @differentiable
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // Input and previous hidden state are concatenated before each gate's matmul.
        let concatenatedInputSize = inputSize + hiddenSize
        let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
        let gateBiasShape = TensorShape([hiddenSize])
        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
        self.inputBias = Tensor(zeros: gateBiasShape)
        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
        self.updateBias = Tensor(zeros: gateBiasShape)
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        // Forget bias starts at ones so the cell initially retains its state.
        self.forgetBias = Tensor(ones: gateBiasShape)
        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
        self.outputBias = Tensor(zeros: gateBiasShape)
    }

    /// The LSTM state: the internal cell state and the externally visible hidden state.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, carrying the time-step input and
    ///   the previous `State`.
    /// - Returns: The new state (also used as the time-step output).
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along axis 1 (the feature axis) so the result matches the
        // [inputSize + hiddenSize, hiddenSize] gate weights.
        let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        // Standard LSTM gates: sigmoid gates in [0, 1], tanh candidate in [-1, 1].
        let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias)
        let updateGate = tanh(matmul(gateInput, updateWeight) + updateBias)
        let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
        let outputGate = sigmoid(matmul(gateInput, outputWeight) + outputBias)

        // New cell state: retain what the forget gate keeps, add the gated candidate.
        let newCellState = input.state.cell * forgetGate + inputGate * updateGate
        let newHiddenState = tanh(newCellState) * outputGate

        let newState = State(cell: newCellState, hidden: newHiddenState)

        return Output(output: newState, state: newState)
    }
}

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ final class LayerTests: XCTestCase {
8282
XCTAssertEqual(output.shape, expected)
8383
}
8484

85+
/// Verifies that `SimpleRNNCell` computes `matmul([input ++ state], weight) + bias`
/// for one time step, using hand-set weights so the result is known exactly.
func testSimpleRNNCell() {
    var cell = SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 5)
    cell.weight = Tensor<Float>(ones: [7, 5]) * Tensor<Float>([0.3333, 1, 0.3333, 1, 0.3333])
    cell.bias = Tensor<Float>(ones: [5])
    let stepInput = Tensor<Float>(ones: [1, 2]) * Tensor<Float>([0.3, 0.7])
    let previousState = Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
    let newState = cell.applied(to: stepInput, state: previousState).state
    XCTAssertEqual(
        newState,
        Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]]))
}
97+
8598
static var allTests = [
8699
("testConv1D", testConv1D),
87100
("testMaxPool1D", testMaxPool1D),
@@ -90,6 +103,7 @@ final class LayerTests: XCTestCase {
90103
("testGlobalAvgPool2D", testGlobalAvgPool2D),
91104
("testGlobalAvgPool3D", testGlobalAvgPool3D),
92105
("testReshape", testReshape),
93-
("testFlatten", testFlatten)
106+
("testFlatten", testFlatten),
107+
("testSimpleRNNCell", testSimpleRNNCell)
94108
]
95109
}

0 commit comments

Comments (0)