/// A recurrent neural network cell.
public protocol RNNCell: Layer where Input == RNNInput<TimeStepInput, State> {
I'm intentionally not adding an `init`. At this point, I have not reached a conclusion about whether `init` is generalizable across all different RNN cells.
LGTM! You explained the design to me in person and it generally makes sense.
One comment is that, though the protocol requirements are individually documented, it's not clear how they interact or why some requirements exist (e.g. `zeroState`). It might be nice to include some extra type-level doc comments elaborating on how things fit together.
(I guess this will be ameliorated when concrete RNN cell implementations conforming to `RNNCell` and using `RNNInput` are added.)
A goal might be to enable a Python RNN developer to quickly understand the Swift RNN cell design after reading doc comments.
Actually, in response to myself: I thought about the requirements individually for a few seconds and everything clicked. I think the only truly weird RNN requirement for Python developers is
But I think the goal above is still valuable. Some people may need more documentation to understand things.
I appreciate your thoroughness, and you are absolutely right; I'm well aware of the quality of API docs I'd like all of us to go for. However, I'd like to emphasize that it's just not the right time yet. As mentioned in the PR description, this is something to unblock #71, and we don't know if it's general enough yet. For instance, when no conforming types exist, one can't even document usage examples. What I'd like to see is for all RNN cells to land first, and then complete high-quality documentation and tutorials based on concrete use cases.
Makes sense! I think we're on the same page.
This PR adds `RNNCell`, which intends to generalize simple RNNs, LSTMs, GRUs, and NTMs. This will enable generic algorithms or layers that can be parameterized by any RNN. See #71 for context.
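To make the design concrete, here is a minimal sketch of how an `RNNCell`-style protocol, `RNNInput`, and `zeroState` might fit together and enable a generic unrolling algorithm. This is not the actual swift-apis code: the `Layer` conformance and `Tensor` types are replaced with plain `Float` values, and the names `RNNCellSketch`, `AccumulatorCell`, and `unroll` are hypothetical, invented for illustration only.

```swift
// Simplified sketch of the RNN cell design (hypothetical, not swift-apis API).

/// The input to a recurrent cell at one time step: the new time-step data
/// plus the state carried over from the previous step.
struct RNNInput<TimeStepInput, State> {
    var input: TimeStepInput
    var state: State
}

/// A hypothetical stand-in for the RNNCell protocol, with Tensor/Layer
/// machinery stripped out for illustration.
protocol RNNCellSketch {
    associatedtype TimeStepInput
    associatedtype State
    /// The initial ("zero") state, used before the first time step.
    var zeroState: State { get }
    /// Processes one time step and returns the new state.
    func callAsFunction(_ input: RNNInput<TimeStepInput, State>) -> State
}

/// A toy cell whose state is a running sum of its inputs.
struct AccumulatorCell: RNNCellSketch {
    var zeroState: Float { 0 }
    func callAsFunction(_ input: RNNInput<Float, Float>) -> Float {
        input.state + input.input
    }
}

/// A generic algorithm parameterized by any cell: threads the state
/// through the whole sequence, starting from zeroState.
func unroll<Cell: RNNCellSketch>(
    _ cell: Cell, over inputs: [Cell.TimeStepInput]
) -> Cell.State {
    var state = cell.zeroState
    for x in inputs {
        state = cell(RNNInput(input: x, state: state))
    }
    return state
}
```

The point of `zeroState` in this picture is that generic code like `unroll` needs a well-defined starting state without knowing anything about the concrete cell; `RNNInput` is what lets a cell's `Input` bundle the per-step data and the recurrent state into a single `Layer`-compatible input type.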
cc @tanmayb123 @eaplatanios