@@ -1293,3 +1293,123 @@ public extension RNNCell {
1293
1293
return applied ( to: RNNCellInput ( input: input, state: state) )
1294
1294
}
1295
1295
}
1296
+
1297
/// A simple RNN cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var weight: Tensor<Scalar>
    public var bias: Tensor<Scalar>

    /// The shape of the hidden state: `[1, hiddenSize]`.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// An all-zeros initial hidden state.
    @differentiable
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, carrying the time-step input
    ///   and the previous hidden state (both 2-D `[batch, features]` tensors).
    /// - Returns: The new hidden state, as both the time-step output and the
    ///   carried state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along the feature axis (axis 1). The default axis (0)
        // would stack batches and produce a shape incompatible with `weight`,
        // which is [inputSize + hiddenSize, hiddenSize]. This also matches
        // `LSTMCell.applied(to:)`, which concatenates `alongAxis: 1`.
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        // Apply the tanh nonlinearity; without it the cell would be a purely
        // linear transformation rather than a simple RNN cell.
        let newState = tanh(matmul(concatenatedInput, weight) + bias)
        return Output(output: newState, state: newState)
    }
}
1342
+
1343
/// An LSTM cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    /// The shape of each state component: `[1, hiddenSize]`.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// A state whose cell and hidden components are both all zeros.
    @differentiable
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // All four gates share the same weight/bias shapes, since each one
        // consumes the concatenated [input, hidden] vector.
        let weightShape = TensorShape([inputSize + hiddenSize, hiddenSize])
        let biasShape = TensorShape([hiddenSize])
        self.inputWeight = Tensor(glorotUniform: weightShape)
        self.inputBias = Tensor(zeros: biasShape)
        self.updateWeight = Tensor(glorotUniform: weightShape)
        self.updateBias = Tensor(zeros: biasShape)
        self.forgetWeight = Tensor(glorotUniform: weightShape)
        // Forget-gate bias starts at one so the gate is initially open.
        self.forgetBias = Tensor(ones: biasShape)
        self.outputWeight = Tensor(glorotUniform: weightShape)
        self.outputBias = Tensor(zeros: biasShape)
    }

    /// The LSTM state: a cell value and a hidden value, both 2-D tensors.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer, carrying the time-step input
    ///   and the previous `State`.
    /// - Returns: The new state, as both the time-step output and the carried
    ///   state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Every gate reads the same concatenation of the current input and the
        // previous hidden state along the feature axis.
        let combined = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        let inGate = sigmoid(matmul(combined, inputWeight) + inputBias)
        let candidate = tanh(matmul(combined, updateWeight) + updateBias)
        let forget = sigmoid(matmul(combined, forgetWeight) + forgetBias)
        let outGate = sigmoid(matmul(combined, outputWeight) + outputBias)

        // Standard LSTM update: keep part of the old cell, mix in the gated
        // candidate, then expose a gated tanh of the cell as the hidden state.
        let cell = input.state.cell * forget + inGate * candidate
        let hidden = tanh(cell) * outGate

        let updated = State(cell: cell, hidden: hidden)
        return Output(output: updated, state: updated)
    }
}
0 commit comments