Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit a141a7d

Browse files
committed
added test for LSTMCell
1 parent c66cf59 commit a141a7d

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,30 @@ final class LayerTests: XCTestCase {
412412
// [ 0.0, 0.0, 0.0, 0.0]])
413413
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
414414
}
415+
416+
func testLSTM() {
417+
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
418+
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
419+
let rnn = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4,
420+
seed: (0xFeed, 0xBeef)))
421+
withTensorLeakChecking {
422+
let (outputs, _) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
423+
return rnn(inputs)
424+
}
425+
426+
XCTAssertEqual(outputs.map { $0.cell },
427+
[[[ 0.1147887, 0.110584, -0.064081416, -0.08400999]],
428+
[[ 0.20066944, 0.20825693, -0.11570193, -0.14060757]],
429+
[[ 0.26505938, 0.29501802, -0.15672679, -0.1794617]],
430+
[[ 0.31350702, 0.37243342, -0.1890606, -0.20662251]]])
431+
432+
XCTAssertEqual(outputs.map { $0.hidden },
433+
[[[ 0.06314508, 0.060653392, -0.029783601, -0.037988894]],
434+
[[ 0.10919889, 0.114127055, -0.053144053, -0.063461654]],
435+
[[ 0.14266092, 0.16068783, -0.07122802, -0.08071505]],
436+
[[ 0.16709672, 0.20094386, -0.0851357, -0.09258326]]])
437+
}
438+
}
415439

416440
func testFunction() {
417441
let tanhLayer = Function<Tensor<Float>, Tensor<Float>>(tanh)
@@ -543,6 +567,7 @@ final class LayerTests: XCTestCase {
543567
("testRNN", testRNN),
544568
("testFunction", testFunction),
545569
("testBatchNorm", testBatchNorm),
546-
("testLayerNorm", testLayerNorm)
570+
("testLayerNorm", testLayerNorm),
571+
("testLSTM", testLSTM)
547572
]
548573
}

0 commit comments

Comments
 (0)