@@ -412,6 +412,30 @@ final class LayerTests: XCTestCase {
412
412
// [ 0.0, 0.0, 0.0, 0.0]])
413
413
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
414
414
}
415
+
416
+ func testLSTM( ) {
417
+ let x = Tensor < Float > ( rangeFrom: 0.0 , to: 0.4 , stride: 0.1 ) . rankLifted ( )
418
+ let inputs : [ Tensor < Float > ] = Array ( repeating: x, count: 4 )
419
+ let rnn = RNN ( LSTMCell < Float > ( inputSize: 4 , hiddenSize: 4 ,
420
+ seed: ( 0xFeed , 0xBeef ) ) )
421
+ withTensorLeakChecking {
422
+ let ( outputs, _) = rnn. valueWithPullback ( at: inputs) { rnn, inputs in
423
+ return rnn ( inputs)
424
+ }
425
+
426
+ XCTAssertEqual ( outputs. map { $0. cell } ,
427
+ [ [ [ 0.1147887 , 0.110584 , - 0.064081416 , - 0.08400999 ] ] ,
428
+ [ [ 0.20066944 , 0.20825693 , - 0.11570193 , - 0.14060757 ] ] ,
429
+ [ [ 0.26505938 , 0.29501802 , - 0.15672679 , - 0.1794617 ] ] ,
430
+ [ [ 0.31350702 , 0.37243342 , - 0.1890606 , - 0.20662251 ] ] ] )
431
+
432
+ XCTAssertEqual ( outputs. map { $0. hidden } ,
433
+ [ [ [ 0.06314508 , 0.060653392 , - 0.029783601 , - 0.037988894 ] ] ,
434
+ [ [ 0.10919889 , 0.114127055 , - 0.053144053 , - 0.063461654 ] ] ,
435
+ [ [ 0.14266092 , 0.16068783 , - 0.07122802 , - 0.08071505 ] ] ,
436
+ [ [ 0.16709672 , 0.20094386 , - 0.0851357 , - 0.09258326 ] ] ] )
437
+ }
438
+ }
415
439
416
440
func testFunction( ) {
417
441
let tanhLayer = Function < Tensor < Float > , Tensor < Float > > ( tanh)
@@ -543,6 +567,7 @@ final class LayerTests: XCTestCase {
543
567
( " testRNN " , testRNN) ,
544
568
( " testFunction " , testFunction) ,
545
569
( " testBatchNorm " , testBatchNorm) ,
546
- ( " testLayerNorm " , testLayerNorm)
570
+ ( " testLayerNorm " , testLayerNorm) ,
571
+ ( " testLSTM " , testLSTM)
547
572
]
548
573
}
0 commit comments