@@ -412,6 +412,27 @@ final class LayerTests: XCTestCase {
412
412
// [ 0.0, 0.0, 0.0, 0.0]])
413
413
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
414
414
}
415
+
416
+ func testLSTM( ) {
417
+ let x = Tensor < Float > ( rangeFrom: 0.0 , to: 0.4 , stride: 0.1 ) . rankLifted ( )
418
+ let inputs : [ Tensor < Float > ] = Array ( repeating: x, count: 4 )
419
+ let rnn = RNN ( LSTMCell < Float > ( inputSize: 4 , hiddenSize: 4 , seed: ( 0xFeed , 0xBeef ) ) )
420
+ withTensorLeakChecking {
421
+ let ( outputs, _) = rnn. valueWithPullback ( at: inputs) { rnn, inputs in
422
+ return rnn ( inputs)
423
+ }
424
+ XCTAssertEqual ( outputs. map { $0. cell } ,
425
+ [ [ [ 0.1147887 , 0.110584 , - 0.064081416 , - 0.08400999 ] ] ,
426
+ [ [ 0.20066944 , 0.20825693 , - 0.11570193 , - 0.14060757 ] ] ,
427
+ [ [ 0.26505938 , 0.29501802 , - 0.15672679 , - 0.1794617 ] ] ,
428
+ [ [ 0.31350702 , 0.37243342 , - 0.1890606 , - 0.20662251 ] ] ] )
429
+ XCTAssertEqual ( outputs. map { $0. hidden } ,
430
+ [ [ [ 0.06314508 , 0.060653392 , - 0.029783601 , - 0.037988894 ] ] ,
431
+ [ [ 0.10919889 , 0.114127055 , - 0.053144053 , - 0.063461654 ] ] ,
432
+ [ [ 0.14266092 , 0.16068783 , - 0.07122802 , - 0.08071505 ] ] ,
433
+ [ [ 0.16709672 , 0.20094386 , - 0.0851357 , - 0.09258326 ] ] ] )
434
+ }
435
+ }
415
436
416
437
func testFunction( ) {
417
438
let tanhLayer = Function < Tensor < Float > , Tensor < Float > > ( tanh)
@@ -541,6 +562,7 @@ final class LayerTests: XCTestCase {
541
562
( " testSimpleRNNCell " , testSimpleRNNCell) ,
542
563
( " testDense " , testDense) ,
543
564
( " testRNN " , testRNN) ,
565
+ ( " testLSTM " , testLSTM) ,
544
566
( " testFunction " , testFunction) ,
545
567
( " testBatchNorm " , testBatchNorm) ,
546
568
( " testLayerNorm " , testLayerNorm)
0 commit comments