Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

added test for LSTMCell #415

Merged
merged 2 commits into from
Aug 5, 2019
Merged

added test for LSTMCell #415

merged 2 commits into from
Aug 5, 2019

Conversation

dhasl002
Copy link
Contributor

@dhasl002 dhasl002 commented Aug 4, 2019

  • added unit test for LSTM cell, this feature was previously untested

@dhasl002 dhasl002 closed this Aug 4, 2019
@dhasl002 dhasl002 reopened this Aug 4, 2019
let (outputs, _) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
return rnn(inputs)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

[[ 0.20066944, 0.20825693, -0.11570193, -0.14060757]],
[[ 0.26505938, 0.29501802, -0.15672679, -0.1794617]],
[[ 0.31350702, 0.37243342, -0.1890606, -0.20662251]]])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
let rnn = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4,
seed: (0xFeed, 0xBeef)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to the end of the previous line.

@@ -543,6 +567,7 @@ final class LayerTests: XCTestCase {
("testRNN", testRNN),
("testFunction", testFunction),
("testBatchNorm", testBatchNorm),
("testLayerNorm", testLayerNorm)
("testLayerNorm", testLayerNorm),
("testLSTM", testLSTM)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Place this between "testRNN" and "testFunction" like how the methods are ordered.

@dhasl002
Copy link
Contributor Author

dhasl002 commented Aug 5, 2019

@rxwei comments addressed in 38bbd20

@rxwei rxwei merged commit 22f5e43 into tensorflow:master Aug 5, 2019
@dhasl002 dhasl002 deleted the lstm_test branch August 6, 2019 02:21
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants