Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit f4cd012

Browse files
authored
Add RNNCell protocol. (#80)
1 parent 5fff54b commit f4cd012

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,3 +1314,44 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
13141314
return input.reshaped(toShape: shape)
13151315
}
13161316
}
1317+
1318+
/// An input to a recurrent neural network.
1319+
public struct RNNInput<TimeStepInput: Differentiable, State: Differentiable>: Differentiable {
1320+
/// The input at the current time step.
1321+
public var timeStepInput: TimeStepInput
1322+
/// The previous state.
1323+
public var previousState: State
1324+
1325+
@differentiable
1326+
public init(timeStepInput: TimeStepInput, previousState: State) {
1327+
self.timeStepInput = timeStepInput
1328+
self.previousState = previousState
1329+
}
1330+
}
1331+
1332+
/// A recurrent neural network cell.
1333+
public protocol RNNCell: Layer where Input == RNNInput<TimeStepInput, State> {
1334+
/// The input at a time step.
1335+
associatedtype TimeStepInput: Differentiable
1336+
/// The state that may be preserved across time steps.
1337+
typealias State = Output
1338+
/// The zero state.
1339+
var zeroState: State { get }
1340+
}
1341+
1342+
public extension RNNCell {
1343+
/// Returns the new state obtained from applying the RNN cell to the input at the current time
1344+
/// step and the previous state.
1345+
///
1346+
/// - Parameters:
1347+
/// - timeStepInput: The input at the current time step.
1348+
/// - previousState: The previous state of the RNN cell.
1349+
/// - context: The contextual information for the layer application, e.g. the current learning
1350+
/// phase.
1351+
/// - Returns: The output.
1352+
@differentiable
1353+
func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State {
1354+
return applied(to: Input(timeStepInput: timeStepInput, previousState: previous),
1355+
in: context)
1356+
}
1357+
}

0 commit comments

Comments
 (0)