@@ -1283,3 +1283,123 @@ public extension RNNCell {
1283
1283
return applied ( to: RNNCellInput ( input: input, state: state) )
1284
1284
}
1285
1285
}
1286
+
1287
/// A Simple RNN Cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var weight: Tensor<Scalar>
    public var bias: Tensor<Scalar>

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// The shape of a single hidden state: `[1, hiddenSize]`
    /// (`hiddenSize` is the second dimension of `weight`).
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// A zero-valued initial hidden state.
    @differentiable
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The hidden state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along the feature axis (1) so the result's second
        // dimension is `inputSize + hiddenSize`, matching `weight`'s first
        // dimension for the matmul below — consistent with
        // `LSTMCell.applied(to:)`. The previous axis-0 concatenation produced
        // a shape incompatible with `weight` for 2-D batched inputs.
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        let newState = matmul(concatenatedInput, weight) + bias
        return Output(output: newState, state: newState)
    }
}
1332
+
1333
/// An LSTM Cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// The state of an LSTM cell: a cell value and a hidden value.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// The shape of each state component: `[1, hiddenSize]`
    /// (`hiddenSize` is the second dimension of the gate weights).
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// A zero-valued initial state (both cell and hidden components).
    @differentiable
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    /// Creates a `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        let gateWeightShape = TensorShape([inputSize + hiddenSize, hiddenSize])
        let gateBiasShape = TensorShape([hiddenSize])
        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
        self.inputBias = Tensor(zeros: gateBiasShape)
        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
        self.updateBias = Tensor(zeros: gateBiasShape)
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        // Forget bias starts at one so the cell initially retains its state.
        self.forgetBias = Tensor(ones: gateBiasShape)
        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
        self.outputBias = Tensor(zeros: gateBiasShape)
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The new state (also used as the time-step output).
    @differentiable
    public func applied(to input: Input) -> Output {
        // All four gates read the same [input, hidden] feature concatenation.
        let gateInput = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        let inputActivation = sigmoid(matmul(gateInput, inputWeight) + inputBias)
        let candidate = tanh(matmul(gateInput, updateWeight) + updateBias)
        let forgetActivation = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
        let outputActivation = sigmoid(matmul(gateInput, outputWeight) + outputBias)

        // cell' = forget ⊙ cell + input ⊙ candidate; hidden' = output ⊙ tanh(cell').
        let updatedCell = input.state.cell * forgetActivation + inputActivation * candidate
        let updatedHidden = tanh(updatedCell) * outputActivation

        let updatedState = State(cell: updatedCell, hidden: updatedHidden)
        return Output(output: updatedState, state: updatedState)
    }
}
0 commit comments