Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 22f5e43

Browse files
dhasl002rxwei
authored andcommitted
added test for LSTMCell (#415)
added unit test for LSTM cell, this feature was previously untested
1 parent 0ef44fa commit 22f5e43

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ final class LayerTests: XCTestCase {
412412
// [ 0.0, 0.0, 0.0, 0.0]])
413413
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
414414
}
415+
416+
func testLSTM() {
417+
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
418+
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
419+
let rnn = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4, seed: (0xFeed, 0xBeef)))
420+
withTensorLeakChecking {
421+
let (outputs, _) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
422+
return rnn(inputs)
423+
}
424+
XCTAssertEqual(outputs.map { $0.cell },
425+
[[[ 0.1147887, 0.110584, -0.064081416, -0.08400999]],
426+
[[ 0.20066944, 0.20825693, -0.11570193, -0.14060757]],
427+
[[ 0.26505938, 0.29501802, -0.15672679, -0.1794617]],
428+
[[ 0.31350702, 0.37243342, -0.1890606, -0.20662251]]])
429+
XCTAssertEqual(outputs.map { $0.hidden },
430+
[[[ 0.06314508, 0.060653392, -0.029783601, -0.037988894]],
431+
[[ 0.10919889, 0.114127055, -0.053144053, -0.063461654]],
432+
[[ 0.14266092, 0.16068783, -0.07122802, -0.08071505]],
433+
[[ 0.16709672, 0.20094386, -0.0851357, -0.09258326]]])
434+
}
435+
}
415436

416437
func testFunction() {
417438
let tanhLayer = Function<Tensor<Float>, Tensor<Float>>(tanh)
@@ -541,6 +562,7 @@ final class LayerTests: XCTestCase {
541562
("testSimpleRNNCell", testSimpleRNNCell),
542563
("testDense", testDense),
543564
("testRNN", testRNN),
565+
("testLSTM", testLSTM),
544566
("testFunction", testFunction),
545567
("testBatchNorm", testBatchNorm),
546568
("testLayerNorm", testLayerNorm)

0 commit comments

Comments
 (0)