Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6dc373a

Browse files
tanmayb123 authored and rxwei committed
Add Recurrent Layers (#71)
`SimpleRNN` and `LSTM`
1 parent 23c16ae commit 6dc373a

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,3 +1283,123 @@ public extension RNNCell {
12831283
return applied(to: RNNCellInput(input: input, state: state))
12841284
}
12851285
}
1286+
1287+
/// A simple RNN cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    /// The combined input-to-hidden and hidden-to-hidden weight matrix,
    /// of shape `[inputSize + hiddenSize, hiddenSize]`.
    public var weight: Tensor<Scalar>
    /// The bias vector, of shape `[hiddenSize]`.
    public var bias: Tensor<Scalar>

    /// The shape of a hidden state for a batch of size 1.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// An all-zeros initial hidden state.
    @differentiable
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, carrying the time-step input
    ///   and the previous hidden state.
    /// - Returns: The new hidden state, as both output and state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along the feature axis (axis 1) so that 2-D batched
        // tensors of shapes [batch, inputSize] and [batch, hiddenSize] combine
        // into [batch, inputSize + hiddenSize]. The default axis (0) would
        // require matching feature counts; this also makes the cell consistent
        // with `LSTMCell.applied(to:)`.
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        // NOTE(review): no nonlinearity is applied here; a conventional simple
        // RNN cell applies `tanh` to this affine transform — confirm intended
        // (the existing unit test pins the purely affine behavior).
        let newState = matmul(concatenatedInput, weight) + bias
        return Output(output: newState, state: newState)
    }
}
1332+
1333+
/// An LSTM cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    // One weight matrix and one bias vector per gate (input, update/candidate,
    // forget, output); each weight has shape [inputSize + hiddenSize, hiddenSize]
    // and each bias has shape [hiddenSize].
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    /// The shape of a cell/hidden state for a batch of size 1.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// An initial state with all-zeros cell and hidden components.
    @differentiable
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates an `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        let concatenatedInputSize = inputSize + hiddenSize
        let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
        let gateBiasShape = TensorShape([hiddenSize])
        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
        self.inputBias = Tensor(zeros: gateBiasShape)
        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
        self.updateBias = Tensor(zeros: gateBiasShape)
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        // Forget bias is initialized to ones (not zeros like the other gates),
        // so the forget gate starts mostly open — a common LSTM initialization
        // to help gradient flow early in training.
        self.forgetBias = Tensor(ones: gateBiasShape)
        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
        self.outputBias = Tensor(zeros: gateBiasShape)
    }

    /// The LSTM state: a cell state and a hidden state, both differentiable.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, carrying the time-step input
    ///   and the previous `State` (cell and hidden).
    /// - Returns: The new `State`, as both output and state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Shared gate input: [batch, inputSize + hiddenSize].
        let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        // Sigmoid gates are in (0, 1); the update (candidate) uses tanh, in (-1, 1).
        let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias)
        let updateGate = tanh(matmul(gateInput, updateWeight) + updateBias)
        let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
        let outputGate = sigmoid(matmul(gateInput, outputWeight) + outputBias)

        // c_t = f * c_{t-1} + i * g;  h_t = o * tanh(c_t)
        let newCellState = input.state.cell * forgetGate + inputGate * updateGate
        let newHiddenState = tanh(newCellState) * outputGate

        let newState = State(cell: newCellState, hidden: newHiddenState)

        return Output(output: newState, state: newState)
    }
}

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ final class LayerTests: XCTestCase {
8282
XCTAssertEqual(output.shape, expected)
8383
}
8484

85+
// Verifies that a `SimpleRNNCell` with fixed weight and bias produces the
// expected hidden state for a single time step.
func testSimpleRNNCell() {
    var rnn = SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 5)
    // Overwrite the random Glorot initialization with deterministic parameters.
    rnn.weight = Tensor<Float>(ones: [7, 5]) * Tensor<Float>([0.3333, 1, 0.3333, 1, 0.3333])
    rnn.bias = Tensor<Float>(ones: [5])
    let initialState = Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
    let stepInput = Tensor<Float>(ones: [1, 2]) * Tensor<Float>([0.3, 0.7])
    let newState = rnn.applied(to: stepInput, state: initialState).state
    // Exact Float comparison is deliberate: a single deterministic affine
    // transform, so the literal below is the bit-exact result.
    XCTAssertEqual(newState, Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]]))
}
97+
8598
static var allTests = [
8699
("testConv1D", testConv1D),
87100
("testMaxPool1D", testMaxPool1D),
@@ -90,6 +103,7 @@ final class LayerTests: XCTestCase {
90103
("testGlobalAvgPool2D", testGlobalAvgPool2D),
91104
("testGlobalAvgPool3D", testGlobalAvgPool3D),
92105
("testReshape", testReshape),
93-
("testFlatten", testFlatten)
106+
("testFlatten", testFlatten),
107+
("testSimpleRNNCell", testSimpleRNNCell)
94108
]
95109
}

0 commit comments

Comments
 (0)