@@ -1276,7 +1276,6 @@ public protocol RNNCell: Layer where Input == RNNCellInput<TimeStepInput, State>
1276
1276
/// The state that may be preserved across time steps.
1277
1277
associatedtype State : Differentiable
1278
1278
/// The zero state.
1279
- @differentiable
1280
1279
var zeroState : State { get }
1281
1280
}
1282
1281
@@ -1293,3 +1292,121 @@ public extension RNNCell {
1293
1292
return applied ( to: RNNCellInput ( input: input, state: state) )
1294
1293
}
1295
1294
}
1295
+
1296
/// A simple RNN cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var weight: Tensor<Scalar>
    public var bias: Tensor<Scalar>

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// The shape of the hidden state: a single batch row of
    /// `weight.shape[1]` (i.e. `hiddenSize`) features.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// The zero-filled initial state.
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // The time-step input and the previous state are concatenated along the
        // feature axis before the single matmul, so the weight's first dimension
        // spans both.
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the cell to the given input.
    ///
    /// - Parameter input: The time-step input together with the previous state.
    /// - Returns: The new hidden state, used as both the time-step output and
    ///   the state carried to the next time step.
    @differentiable
    public func applied(to input: Input) -> Output {
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        // NOTE(review): no nonlinearity is applied here — the new state is the
        // raw affine transform. Confirm this is intentional for this cell.
        let newState = matmul(concatenatedInput, weight) + bias
        return Output(output: newState, state: newState)
    }
}
1340
+
1341
/// An LSTM cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// The shape of each state component: a single batch row of
    /// `inputWeight.shape[1]` (i.e. `hiddenSize`) features.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// The zero-filled initial state (both cell and hidden components).
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    /// Creates an `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // Each gate consumes the concatenation of the time-step input and the
        // previous hidden state, so all four gates share the same weight shape.
        let concatenatedInputSize = inputSize + hiddenSize
        let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
        let gateBiasShape = TensorShape([hiddenSize])
        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
        self.inputBias = Tensor(zeros: gateBiasShape)
        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
        self.updateBias = Tensor(zeros: gateBiasShape)
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        // The forget bias starts at ones (not zeros like the other gates),
        // biasing the forget gate open — presumably the common LSTM
        // initialization trick to ease gradient flow early in training.
        self.forgetBias = Tensor(ones: gateBiasShape)
        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
        self.outputBias = Tensor(zeros: gateBiasShape)
    }

    /// The LSTM state: the internal cell state plus the exposed hidden state.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// Returns the output obtained from applying the cell to the given input.
    ///
    /// - Parameter input: The time-step input together with the previous state.
    /// - Returns: The new state, used as both the time-step output and the
    ///   state carried to the next time step.
    @differentiable
    public func applied(to input: Input) -> Output {
        // All four gates read the same concatenation of input and previous
        // hidden state.
        let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        let inputGate = sigmoid(matmul(gateInput, inputWeight) + inputBias)
        let updateGate = tanh(matmul(gateInput, updateWeight) + updateBias)
        let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
        let outputGate = sigmoid(matmul(gateInput, outputWeight) + outputBias)

        // Standard LSTM recurrence: keep a forget-gated fraction of the old
        // cell state and add the gated candidate update.
        let newCellState = input.state.cell * forgetGate + inputGate * updateGate
        let newHiddenState = tanh(newCellState) * outputGate

        let newState = State(cell: newCellState, hidden: newHiddenState)

        return Output(output: newState, state: newState)
    }
}
0 commit comments