Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add RNN wrapper for Cells #105

Merged
merged 6 commits into from
Apr 20, 2019
Merged

Add RNN wrapper for Cells #105

merged 6 commits into from
Apr 20, 2019

Conversation

rxwei
Copy link
Contributor

@rxwei rxwei commented Apr 20, 2019

Based off of #100 by @tanmayb123. My push to that branch accidentally closed that PR, so I'm starting a new one.

  • Add generic structure type RNN<Cell: RNNCell> that forms a recurrent layer from a cell. Thanks @tanmayb123 for starting this!
  • Define a custom efficient derivative for RNN.call(_:).
  • Make RNN structures (SimpleRNNCell, LSTMCell, RNN) conditionally conform to VectorNumeric and VectorNumeric for easier integration with more efficient optimizers in the future.
  • Make RNN cell initializers take a random seed, passing them through Tensor.init(glorotUniform:seed:).
  • Add a test for RNN using SimpleRNNCell.

Related to #52. Resolves #91.

@rxwei rxwei added the enhancement New feature or request label Apr 20, 2019
@rxwei rxwei requested review from dan-zheng and jekbradbury April 20, 2019 05:12
@googlebot
Copy link

So there's good news and bad news.

👍 The good news is that everyone that needs to sign a CLA (the pull request submitter and all commit authors) have done so. Everything is all good there.

😕 The bad news is that it appears that one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that here in the pull request.

Note to project maintainer: This is a terminal state, meaning the cla/google commit status will not change from this state. It's up to you to confirm consent of all the commit author(s), set the cla label to yes (if enabled on your project), and then merge this pull request when appropriate.

ℹ️ Googlers: Go here for more info.

@rxwei rxwei mentioned this pull request Apr 20, 2019
@rxwei
Copy link
Contributor Author

rxwei commented Apr 20, 2019

For the record, CLA passed in #100.

𝛁state = 𝛁input.state
reversed𝛁inputs.append(𝛁input.input)
}
return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Regarding this array reversal, while I could've zero-initialized an array of timeStepCount tensors and modified them in reverse order, it would be less efficient because of the cost of heap-allocating timeStepCount extra tensors.

@tanmayb123
Copy link
Contributor

I think CI isn't working.

@rxwei
Copy link
Contributor Author

rxwei commented Apr 20, 2019

Yup, still investigating. Locally all tests are passing.

@rxwei rxwei merged commit 3f71684 into tensorflow:master Apr 20, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Wrapper to enable recurrent cells to handle sequential input
3 participants