Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 861d1f5

Browse files
authored
Improve RNN cell abstraction. (#86)
The existing `RNNCell` protocol and `RNNInput` type are not flexible in that each time step has the take both the previous output and the hidden state. This PR lifts that restriction. * Rename `RNNInput` to `RNNCellInput` so that it's more accurate. * Add a new `RNNCellOutput` generic structure type that stores an output and a state. * Add associated type `State` in `RNNCell`. * Make the `Output` type of `RNNCell` be `RNNOutput<TimeStepOutput, State>`. Thanks @superbobry for the suggestions.
1 parent 1944e47 commit 861d1f5

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,26 +1318,44 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
13181318
}
13191319

13201320
/// An input to a recurrent neural network.
1321-
public struct RNNInput<TimeStepInput: Differentiable, State: Differentiable>: Differentiable {
1321+
public struct RNNCellInput<Input: Differentiable, State: Differentiable>: Differentiable {
13221322
/// The input at the current time step.
1323-
public var timeStepInput: TimeStepInput
1323+
public var input: Input
13241324
/// The previous state.
1325-
public var previousState: State
1325+
public var state: State
13261326

13271327
@differentiable
1328-
public init(timeStepInput: TimeStepInput, previousState: State) {
1329-
self.timeStepInput = timeStepInput
1330-
self.previousState = previousState
1328+
public init(input: Input, state: State) {
1329+
self.input = input
1330+
self.state = state
1331+
}
1332+
}
1333+
1334+
/// An output to a recurrent neural network.
1335+
public struct RNNCellOutput<Output: Differentiable, State: Differentiable>: Differentiable {
1336+
/// The output at the current time step.
1337+
public var output: Output
1338+
/// The current state.
1339+
public var state: State
1340+
1341+
@differentiable
1342+
public init(output: Output, state: State) {
1343+
self.output = output
1344+
self.state = state
13311345
}
13321346
}
13331347

13341348
/// A recurrent neural network cell.
1335-
public protocol RNNCell: Layer where Input == RNNInput<TimeStepInput, State> {
1349+
public protocol RNNCell: Layer where Input == RNNCellInput<TimeStepInput, State>,
1350+
Output == RNNCellOutput<TimeStepOutput, State> {
13361351
/// The input at a time step.
13371352
associatedtype TimeStepInput: Differentiable
1353+
/// The output at a time step.
1354+
associatedtype TimeStepOutput: Differentiable
13381355
/// The state that may be preserved across time steps.
1339-
typealias State = Output
1356+
associatedtype State: Differentiable
13401357
/// The zero state.
1358+
@differentiable
13411359
var zeroState: State { get }
13421360
}
13431361

@@ -1352,8 +1370,11 @@ public extension RNNCell {
13521370
/// phase.
13531371
/// - Returns: The output.
13541372
@differentiable
1355-
func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State {
1356-
return applied(to: Input(timeStepInput: timeStepInput, previousState: previous),
1357-
in: context)
1373+
func applied(
1374+
to input: TimeStepInput,
1375+
state: State,
1376+
in context: Context
1377+
) -> RNNCellOutput<TimeStepOutput, State> {
1378+
return applied(to: RNNCellInput(input: input, state: state), in: context)
13581379
}
13591380
}

0 commit comments

Comments
 (0)