This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add RNNCell protocol. #80

Merged (4 commits) on Apr 7, 2019
41 changes: 41 additions & 0 deletions Sources/DeepLearning/Layer.swift
@@ -1314,3 +1314,44 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
return input.reshaped(toShape: shape)
}
}

/// An input to a recurrent neural network.
public struct RNNInput<TimeStepInput: Differentiable, State: Differentiable>: Differentiable {
/// The input at the current time step.
public var timeStepInput: TimeStepInput
/// The previous state.
public var previousState: State

@differentiable
public init(timeStepInput: TimeStepInput, previousState: State) {
self.timeStepInput = timeStepInput
self.previousState = previousState
}
}

/// A recurrent neural network cell.
public protocol RNNCell: Layer where Input == RNNInput<TimeStepInput, State> {
Inline review comment from the contributor author:
I'm intentionally not adding an init. At this point, I have not reached a conclusion about whether init is generalizable across all different RNN cells.

/// The input at a time step.
associatedtype TimeStepInput: Differentiable
/// The state that may be preserved across time steps.
typealias State = Output
/// The zero state.
var zeroState: State { get }
}

public extension RNNCell {
/// Returns the new state obtained from applying the RNN cell to the input at the current time
/// step and the previous state.
///
/// - Parameters:
/// - timeStepInput: The input at the current time step.
///   - previous: The previous state of the RNN cell.
/// - context: The contextual information for the layer application, e.g. the current learning
/// phase.
/// - Returns: The new state.
@differentiable
func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State {
return applied(to: Input(timeStepInput: timeStepInput, previousState: previous),
in: context)
}
}
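
For illustration only, here is a hypothetical minimal conformance sketching how a cell might adopt the RNNCell protocol above. It is not part of this pull request: the name MinimalRNNCell, its stored weights, and their shapes are invented, and the sketch assumes the applied(to:in:)/Context Layer API used elsewhere in this file.

/// A hypothetical minimal RNN cell: newState = tanh(input * W + previousState * U).
public struct MinimalRNNCell: RNNCell {
    /// The input-to-state weight matrix, shaped [inputSize, hiddenSize].
    public var weight: Tensor<Float>
    /// The state-to-state weight matrix, shaped [hiddenSize, hiddenSize].
    public var recurrentWeight: Tensor<Float>

    public typealias TimeStepInput = Tensor<Float>

    /// A zero-valued initial state with a batch size of 1.
    public var zeroState: Tensor<Float> {
        return Tensor(zeros: [1, recurrentWeight.shape[1]])
    }

    @differentiable
    public func applied(to input: RNNInput<Tensor<Float>, Tensor<Float>>,
                        in context: Context) -> Tensor<Float> {
        // Combine the current time-step input with the previous state.
        return tanh(matmul(input.timeStepInput, weight) +
            matmul(input.previousState, recurrentWeight))
    }
}

Stepping such a cell over a sequence then amounts to repeatedly calling the applied(to:previous:in:) convenience above, threading each returned state into the next call and starting from zeroState.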