@@ -1314,3 +1314,44 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
1314
1314
return input. reshaped ( toShape: shape)
1315
1315
}
1316
1316
}
1317
+
1318
+ /// An input to a recurrent neural network.
1319
+ public struct RNNInput < TimeStepInput: Differentiable , State: Differentiable > : Differentiable {
1320
+ /// The input at the current time step.
1321
+ public var timeStepInput : TimeStepInput
1322
+ /// The previous state.
1323
+ public var previousState : State
1324
+
1325
+ @differentiable
1326
+ public init ( timeStepInput: TimeStepInput , previousState: State ) {
1327
+ self . timeStepInput = timeStepInput
1328
+ self . previousState = previousState
1329
+ }
1330
+ }
1331
+
1332
+ /// A recurrent neural network cell.
1333
+ public protocol RNNCell : Layer where Input == RNNInput < TimeStepInput , State > {
1334
+ /// The input at a time step.
1335
+ associatedtype TimeStepInput : Differentiable
1336
+ /// The state that may be preserved across time steps.
1337
+ typealias State = Output
1338
+ /// The zero state.
1339
+ var zeroState : State { get }
1340
+ }
1341
+
1342
+ public extension RNNCell {
1343
+ /// Returns the new state obtained from applying the RNN cell to the input at the current time
1344
+ /// step and the previous state.
1345
+ ///
1346
+ /// - Parameters:
1347
+ /// - timeStepInput: The input at the current time step.
1348
+ /// - previousState: The previous state of the RNN cell.
1349
+ /// - context: The contextual information for the layer application, e.g. the current learning
1350
+ /// phase.
1351
+ /// - Returns: The output.
1352
+ @differentiable
1353
+ func applied( to timeStepInput: TimeStepInput , previous: State , in context: Context ) -> State {
1354
+ return applied ( to: Input ( timeStepInput: timeStepInput, previousState: previous) ,
1355
+ in: context)
1356
+ }
1357
+ }
0 commit comments