@@ -1283,123 +1283,3 @@ public extension RNNCell {
1283
1283
return applied ( to: RNNCellInput ( input: input, state: state) )
1284
1284
}
1285
1285
}
1286
-
1287
- /// A Simple RNN Cell.
1288
- public struct SimpleRNNCell < Scalar: TensorFlowFloatingPoint > : RNNCell {
1289
- public var weight : Tensor < Scalar >
1290
- public var bias : Tensor < Scalar >
1291
-
1292
- @noDerivative public var stateShape : TensorShape {
1293
- return TensorShape ( [ 1 , weight. shape [ 1 ] ] )
1294
- }
1295
-
1296
- @differentiable
1297
- public var zeroState : Tensor < Scalar > {
1298
- return Tensor ( zeros: stateShape)
1299
- }
1300
-
1301
- public typealias State = Tensor < Scalar >
1302
- public typealias TimeStepInput = Tensor < Scalar >
1303
- public typealias TimeStepOutput = State
1304
- public typealias Input = RNNCellInput < TimeStepInput , State >
1305
- public typealias Output = RNNCellOutput < TimeStepOutput , State >
1306
-
1307
- /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
1308
- ///
1309
- /// - Parameters:
1310
- /// - inputSize: The number of features in 2-D input tensors.
1311
- /// - hiddenSize: The number of features in 2-D hidden states.
1312
- public init ( inputSize: Int , hiddenSize: Int ) {
1313
- let concatenatedInputSize = inputSize + hiddenSize
1314
- self . weight = Tensor ( glorotUniform: [ concatenatedInputSize, hiddenSize] )
1315
- self . bias = Tensor ( zeros: [ hiddenSize] )
1316
- }
1317
-
1318
- /// Returns the output obtained from applying the layer to the given input.
1319
- ///
1320
- /// - Parameters:
1321
- /// - input: The input to the layer.
1322
- /// - context: The contextual information for the layer application, e.g. the current learning
1323
- /// phase.
1324
- /// - Returns: The hidden state.
1325
- @differentiable
1326
- public func applied( to input: Input ) -> Output {
1327
- let concatenatedInput = input. input. concatenated ( with: input. state)
1328
- let newState = matmul ( concatenatedInput, weight) + bias
1329
- return Output ( output: newState, state: newState)
1330
- }
1331
- }
1332
-
1333
- /// An LSTM Cell.
1334
- public struct LSTMCell < Scalar: TensorFlowFloatingPoint > : RNNCell {
1335
- public var inputWeight , updateWeight , forgetWeight , outputWeight : Tensor < Scalar >
1336
- public var inputBias , updateBias , forgetBias , outputBias : Tensor < Scalar >
1337
-
1338
- @noDerivative public var stateShape : TensorShape {
1339
- return TensorShape ( [ 1 , inputWeight. shape [ 1 ] ] )
1340
- }
1341
-
1342
- @differentiable
1343
- public var zeroState : State {
1344
- return State ( cell: Tensor ( zeros: stateShape) , hidden: Tensor ( zeros: stateShape) )
1345
- }
1346
-
1347
- public typealias TimeStepInput = Tensor < Scalar >
1348
- public typealias TimeStepOutput = State
1349
- public typealias Input = RNNCellInput < TimeStepInput , State >
1350
- public typealias Output = RNNCellOutput < TimeStepOutput , State >
1351
-
1352
- /// Creates a `LSTMCell` with the specified input size and hidden state size.
1353
- ///
1354
- /// - Parameters:
1355
- /// - inputSize: The number of features in 2-D input tensors.
1356
- /// - hiddenSize: The number of features in 2-D hidden states.
1357
- public init ( inputSize: Int , hiddenSize: Int ) {
1358
- let concatenatedInputSize = inputSize + hiddenSize
1359
- let gateWeightShape = TensorShape ( [ concatenatedInputSize, hiddenSize] )
1360
- let gateBiasShape = TensorShape ( [ hiddenSize] )
1361
- self . inputWeight = Tensor ( glorotUniform: gateWeightShape)
1362
- self . inputBias = Tensor ( zeros: gateBiasShape)
1363
- self . updateWeight = Tensor ( glorotUniform: gateWeightShape)
1364
- self . updateBias = Tensor ( zeros: gateBiasShape)
1365
- self . forgetWeight = Tensor ( glorotUniform: gateWeightShape)
1366
- self . forgetBias = Tensor ( ones: gateBiasShape)
1367
- self . outputWeight = Tensor ( glorotUniform: gateWeightShape)
1368
- self . outputBias = Tensor ( zeros: gateBiasShape)
1369
- }
1370
-
1371
- public struct State : Differentiable {
1372
- public var cell : Tensor < Scalar >
1373
- public var hidden : Tensor < Scalar >
1374
-
1375
- @differentiable
1376
- public init ( cell: Tensor < Scalar > , hidden: Tensor < Scalar > ) {
1377
- self . cell = cell
1378
- self . hidden = hidden
1379
- }
1380
- }
1381
-
1382
- /// Returns the output obtained from applying the layer to the given input.
1383
- ///
1384
- /// - Parameters:
1385
- /// - input: The input to the layer.
1386
- /// - context: The contextual information for the layer application, e.g. the current learning
1387
- /// phase.
1388
- /// - Returns: The hidden state.
1389
- @differentiable
1390
- public func applied( to input: Input ) -> Output {
1391
- let gateInput = input. input. concatenated ( with: input. state. hidden, alongAxis: 1 )
1392
-
1393
- let inputGate = sigmoid ( matmul ( gateInput, inputWeight) + inputBias)
1394
- let updateGate = tanh ( matmul ( gateInput, updateWeight) + updateBias)
1395
- let forgetGate = sigmoid ( matmul ( gateInput, forgetWeight) + forgetBias)
1396
- let outputGate = sigmoid ( matmul ( gateInput, outputWeight) + outputBias)
1397
-
1398
- let newCellState = input. state. cell * forgetGate + inputGate * updateGate
1399
- let newHiddenState = tanh ( newCellState) * outputGate
1400
-
1401
- let newState = State ( cell: newCellState, hidden: newHiddenState)
1402
-
1403
- return Output ( output: newState, state: newState)
1404
- }
1405
- }
0 commit comments