Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 904dd50

Browse files
dan-zheng authored and rxwei committed
Revert "Revert "Add Recurrent Layers"" (#97)
* Revert "Revert "Add Recurrent Layers (#71)" (#94)" This reverts commit f75c5e0. * Remove `@differentiable` from `zeroState`. * Fix axis of concatenation.
1 parent 3139e9c commit 904dd50

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,6 @@ public protocol RNNCell: Layer where Input == RNNCellInput<TimeStepInput, State>
12761276
/// The state that may be preserved across time steps.
12771277
associatedtype State: Differentiable
12781278
/// The zero state.
1279-
@differentiable
12801279
var zeroState: State { get }
12811280
}
12821281

@@ -1293,3 +1292,121 @@ public extension RNNCell {
12931292
return applied(to: RNNCellInput(input: input, state: state))
12941293
}
12951294
}
1295+
1296+
/// A simple RNN cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var weight: Tensor<Scalar>
    public var bias: Tensor<Scalar>

    /// The shape of the hidden state: `[1, hiddenSize]`, where `hiddenSize` is
    /// the second dimension of `weight`.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// The zero-filled initial state.
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // The input and the previous state are concatenated before the matmul,
        // so the weight's first dimension covers both.
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The time-step input and the previous state.
    /// - Returns: The new hidden state, carried both as the time-step output
    ///   and as the state passed to the next time step.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along axis 1 (the feature axis of 2-D `[batch, features]` tensors).
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        // NOTE(review): no nonlinearity is applied to the new state — confirm intended.
        let newState = matmul(concatenatedInput, weight) + bias
        return Output(output: newState, state: newState)
    }
}
1340+
1341+
/// An LSTM cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    /// The shape of both the cell state and the hidden state: `[1, hiddenSize]`,
    /// where `hiddenSize` is the second dimension of the gate weights.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// The zero-filled initial state (both cell and hidden components).
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // Each gate consumes the concatenation of the input and the hidden state.
        let concatenatedInputSize = inputSize + hiddenSize
        let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
        let gateBiasShape = TensorShape([hiddenSize])
        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
        self.inputBias = Tensor(zeros: gateBiasShape)
        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
        self.updateBias = Tensor(zeros: gateBiasShape)
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        // Forget bias starts at one so the cell initially retains its state.
        self.forgetBias = Tensor(ones: gateBiasShape)
        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
        self.outputBias = Tensor(zeros: gateBiasShape)
    }

    /// The LSTM state: a cell component and a hidden component.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The time-step input and the previous state.
    /// - Returns: The new state (cell and hidden components), carried both as
    ///   the time-step output and as the state passed to the next time step.
    @differentiable
    public func applied(to input: Input) -> Output {
        // All four gates share the same concatenated [input, hidden] features (axis 1).
        let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias)
        let updateGate = tanh(matmul(gateInput, updateWeight) + updateBias)
        let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
        let outputGate = sigmoid(matmul(gateInput, outputWeight) + outputBias)

        // Standard LSTM update: keep part of the old cell, mix in the candidate update.
        let newCellState = input.state.cell * forgetGate + inputGate * updateGate
        let newHiddenState = tanh(newCellState) * outputGate

        let newState = State(cell: newCellState, hidden: newHiddenState)

        return Output(output: newState, state: newState)
    }
}

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ final class LayerTests: XCTestCase {
8282
XCTAssertEqual(output.shape, expected)
8383
}
8484

85+
/// Verifies that `SimpleRNNCell` produces the expected state for a known
/// combination of weight, bias, previous state, and time-step input.
func testSimpleRNNCell() {
    var cell = SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 5)
    cell.weight = Tensor<Float>(ones: [7, 5]) * Tensor<Float>([0.3333, 1, 0.3333, 1, 0.3333])
    cell.bias = Tensor<Float>(ones: [5])
    let previousState = Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
    let timeStepInput = Tensor<Float>(ones: [1, 2]) * Tensor<Float>([0.3, 0.7])
    let newState = cell.applied(to: timeStepInput, state: previousState).state
    XCTAssertEqual(newState, Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]]))
}
97+
8598
static var allTests = [
8699
("testConv1D", testConv1D),
87100
("testMaxPool1D", testMaxPool1D),
@@ -90,6 +103,7 @@ final class LayerTests: XCTestCase {
90103
("testGlobalAvgPool2D", testGlobalAvgPool2D),
91104
("testGlobalAvgPool3D", testGlobalAvgPool3D),
92105
("testReshape", testReshape),
93-
("testFlatten", testFlatten)
106+
("testFlatten", testFlatten),
107+
("testSimpleRNNCell", testSimpleRNNCell)
94108
]
95109
}

0 commit comments

Comments (0)