This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Fix derivative of RNN.callAsFunction(_:initialState:). #660

Merged
dan-zheng merged 1 commit into tensorflow:master from fix-rnn-derivative
Feb 4, 2020

Conversation

dan-zheng
Member

Previously, `RNN._vjpcallAsFunction(_:initialState:)` incorrectly used a zero
initial state. Now, it uses `initialState` as the initial state.

Add `RNN` gradient tests for `SimpleRNNCell`, `LSTMCell`, and `GRUCell`.
TODO: Verify that the gradients are correct using a reference implementation.
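
To illustrate why the starting state matters in the reverse pass, here is a minimal sketch in plain Swift. It is not the swift-apis implementation: `ScalarRNN`, `states`, `gradientWrtU`, and the `replayFrom` parameter are hypothetical names for a toy scalar recurrence with a hand-written backward pass. The gradient with respect to the recurrent weight multiplies each step's derivative by the previous state, so replaying the forward loop from zero instead of from `initialState` changes the result.

```swift
import Foundation

/// Toy scalar RNN (hypothetical): h_t = tanh(w * x_t + u * h_{t-1}).
/// The "loss" is simply the final state h_T.
struct ScalarRNN {
    var w: Double
    var u: Double

    /// Runs the recurrence and returns every state, with `initialState` at index 0.
    func states(_ inputs: [Double], initialState: Double) -> [Double] {
        var result = [initialState]
        for x in inputs {
            result.append(tanh(w * x + u * result.last!))
        }
        return result
    }

    /// Hand-written reverse pass for d(h_T)/du.
    /// `replayFrom` is the state the reverse pass replays the forward loop from:
    /// passing the real initial state models the fix, passing 0 models the bug.
    func gradientWrtU(_ inputs: [Double], replayFrom: Double) -> Double {
        let h = states(inputs, initialState: replayFrom)
        var upstream = 1.0  // d(loss)/d(h_T)
        var gradU = 0.0
        for t in stride(from: inputs.count, to: 0, by: -1) {
            let local = upstream * (1 - h[t] * h[t])  // back through tanh
            gradU += local * h[t - 1]                 // h_{t-1} multiplies u
            upstream = local * u                      // flow into h_{t-1}
        }
        return gradU
    }
}

let rnn = ScalarRNN(w: 0.5, u: 0.8)
let inputs = [1.0, -0.5, 0.25]
let fixed = rnn.gradientWrtU(inputs, replayFrom: 0.7)  // replay from the real initial state
let buggy = rnn.gradientWrtU(inputs, replayFrom: 0.0)  // replay from a zero state
print("fixed:", fixed, "buggy:", buggy)
```

With a single time step and a zero replay state, the term `h[t - 1]` is zero, so the buggy variant reports no gradient for `u` at all, even when the real initial state is nonzero.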
@dan-zheng dan-zheng requested a review from marcrasi February 3, 2020 23:32
@@ -1190,22 +1189,62 @@ final class LayerTests: XCTestCase {
[ 0.074910110, 0.021107012, -0.049724963, -0.069670826],
[ 0.078670055, 0.022462710, -0.051899005, -0.075331904]],
accuracy: 1e-6)
let (𝛁lstm, _) = pullback(.init(inputs.map { LSTMCell<Float>.State(cell: $0, hidden: $0) }))
// TODO: Verify that LSTM gradients are correct using a reference implementation.
Contributor

@marcrasi marcrasi Feb 3, 2020

I'm interested in doing this. Should we do something like check some TF Python code into this repository that reproduces the gradients?

Member Author

That sounds good!
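
Until a reference script is checked in, a lighter sanity check is to compare against central finite differences. This is a different technique from the TF Python reference proposed above, and the sketch below is plain Swift with hypothetical names (`numericalGradient`, `toyLoss`); it does not use the swift-apis `RNN` type. It assumes the loss can be written as a function of a flat array of `Double` parameters.

```swift
import Foundation

/// Central-difference estimate of the gradient of `f` at `point`.
func numericalGradient(
    of f: ([Double]) -> Double,
    at point: [Double],
    epsilon: Double = 1e-6
) -> [Double] {
    var gradient: [Double] = []
    for i in point.indices {
        var plus = point
        var minus = point
        plus[i] += epsilon
        minus[i] -= epsilon
        gradient.append((f(plus) - f(minus)) / (2 * epsilon))
    }
    return gradient
}

/// Hypothetical stand-in for an RNN loss: two tanh steps over fixed inputs.
/// Parameters are packed as [w, u, initialState].
func toyLoss(_ p: [Double]) -> Double {
    let (w, u, h0) = (p[0], p[1], p[2])
    let h1 = tanh(w * 1.0 + u * h0)
    return tanh(w * -0.5 + u * h1)
}

// The last entry estimates d(loss)/d(initialState); an analytic gradient that
// ignores the initial state would disagree with it.
let estimate = numericalGradient(of: toyLoss, at: [0.5, 0.8, 0.7])
print(estimate)
```

Finite differences only catch gross errors like the zero-state bug; they do not replace a reference implementation for checking values to `1e-6` accuracy.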

@dan-zheng dan-zheng merged commit 94ab4f5 into tensorflow:master Feb 4, 2020
@dan-zheng dan-zheng deleted the fix-rnn-derivative branch February 4, 2020 00:00