@@ -1281,7 +1281,7 @@ public extension RNNCell {
1281
1281
}
1282
1282
1283
1283
/// A Simple RNN Cell.
1284
- public struct SimpleRNNCell < Scalar: TensorFlowFloatingPoint > : RNNCell {
1284
+ public struct SimpleRNNCell < Scalar: TensorFlowFloatingPoint > : RNNCell , VectorNumeric {
1285
1285
public var weight : Tensor < Scalar >
1286
1286
public var bias : Tensor < Scalar >
1287
1287
@@ -1304,9 +1304,13 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
1304
1304
/// - Parameters:
1305
1305
/// - inputSize: The number of features in 2-D input tensors.
1306
1306
/// - hiddenSize: The number of features in 2-D hidden states.
1307
- public init ( inputSize: Int , hiddenSize: Int ) {
1307
+ /// - seed: The random seed for initialization. The default value is random.
1308
+ public init ( inputSize: Int , hiddenSize: Int ,
1309
+ seed: ( Int64 , Int64 ) = ( Int64 . random ( in: Int64 . min..< Int64 . max) ,
1310
+ Int64 . random ( in: Int64 . min..< Int64 . max) ) ) {
1308
1311
let concatenatedInputSize = inputSize + hiddenSize
1309
- self . weight = Tensor ( glorotUniform: [ concatenatedInputSize, hiddenSize] )
1312
+ self . weight = Tensor ( glorotUniform: [ concatenatedInputSize, hiddenSize] ,
1313
+ seed: seed)
1310
1314
self . bias = Tensor ( zeros: [ hiddenSize] )
1311
1315
}
1312
1316
@@ -1326,7 +1330,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
1326
1330
}
1327
1331
1328
1332
/// An LSTM Cell.
1329
- public struct LSTMCell < Scalar: TensorFlowFloatingPoint > : RNNCell {
1333
+ public struct LSTMCell < Scalar: TensorFlowFloatingPoint > : RNNCell , VectorNumeric {
1330
1334
public var inputWeight , updateWeight , forgetWeight , outputWeight : Tensor < Scalar >
1331
1335
public var inputBias , updateBias , forgetBias , outputBias : Tensor < Scalar >
1332
1336
@@ -1348,17 +1352,19 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
1348
1352
/// - Parameters:
1349
1353
/// - inputSize: The number of features in 2-D input tensors.
1350
1354
/// - hiddenSize: The number of features in 2-D hidden states.
1351
- public init ( inputSize: Int , hiddenSize: Int ) {
1355
+ public init ( inputSize: Int , hiddenSize: Int ,
1356
+ seed: ( Int64 , Int64 ) = ( Int64 . random ( in: Int64 . min..< Int64 . max) ,
1357
+ Int64 . random ( in: Int64 . min..< Int64 . max) ) ) {
1352
1358
let concatenatedInputSize = inputSize + hiddenSize
1353
1359
let gateWeightShape = TensorShape ( [ concatenatedInputSize, hiddenSize] )
1354
1360
let gateBiasShape = TensorShape ( [ hiddenSize] )
1355
- self . inputWeight = Tensor ( glorotUniform: gateWeightShape)
1361
+ self . inputWeight = Tensor ( glorotUniform: gateWeightShape, seed : seed )
1356
1362
self . inputBias = Tensor ( zeros: gateBiasShape)
1357
- self . updateWeight = Tensor ( glorotUniform: gateWeightShape)
1363
+ self . updateWeight = Tensor ( glorotUniform: gateWeightShape, seed : seed )
1358
1364
self . updateBias = Tensor ( zeros: gateBiasShape)
1359
- self . forgetWeight = Tensor ( glorotUniform: gateWeightShape)
1365
+ self . forgetWeight = Tensor ( glorotUniform: gateWeightShape, seed : seed )
1360
1366
self . forgetBias = Tensor ( ones: gateBiasShape)
1361
- self . outputWeight = Tensor ( glorotUniform: gateWeightShape)
1367
+ self . outputWeight = Tensor ( glorotUniform: gateWeightShape, seed : seed )
1362
1368
self . outputBias = Tensor ( zeros: gateBiasShape)
1363
1369
}
1364
1370
@@ -1397,3 +1403,91 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
1397
1403
return Output ( output: newState, state: newState)
1398
1404
}
1399
1405
}
1406
/// A recurrent neural network: applies `Cell` to each element of a sequence of
/// time-step inputs, threading the hidden state from one step to the next.
public struct RNN<Cell: RNNCell>: Layer {
    public typealias Input = [Cell.TimeStepInput]
    public typealias Output = [Cell.TimeStepOutput]

    /// The recurrent cell applied at every time step.
    public var cell: Cell

    /// Creates an RNN wrapping the given recurrent cell.
    ///
    /// - Parameter cell: An autoclosure producing the cell; evaluated once.
    public init(_ cell: @autoclosure () -> Cell) {
        self.cell = cell()
    }

    /// Applies the cell to every time step of `input`, starting from
    /// `initialState`, and returns the per-step outputs in order.
    ///
    /// - Parameters:
    ///   - input: The sequence of time-step inputs.
    ///   - initialState: The hidden state fed to the first time step.
    @differentiable(wrt: (self, input), vjp: _vjpCall(_:initialState:))
    public func call(_ input: [Cell.TimeStepInput],
                     initialState: Cell.State) -> [Cell.TimeStepOutput] {
        var currentHiddenState = initialState
        var timeStepOutputs: [Cell.TimeStepOutput] = []
        for timestep in input {
            let output = cell(input: timestep, state: currentHiddenState)
            currentHiddenState = output.state
            timeStepOutputs.append(output.output)
        }
        return timeStepOutputs
    }

    /// Custom VJP for `call(_:initialState:)`: replays the forward pass while
    /// recording a backpropagator per time step, then runs them in reverse.
    @usableFromInline
    internal func _vjpCall(
        _ inputs: [Cell.TimeStepInput], initialState: Cell.State
    ) -> ([Cell.TimeStepOutput],
          (Array<Cell.TimeStepOutput>.CotangentVector)
              -> (CotangentVector, Array<Cell.TimeStepInput>.CotangentVector)) {
        let timeStepCount = inputs.count
        // Must mirror `call(_:initialState:)` exactly: the previous code began
        // from `cell.zeroState`, silently ignoring `initialState`, so the
        // recorded pullbacks corresponded to a different forward trajectory.
        var currentHiddenState = initialState
        var timeStepOutputs: [Cell.TimeStepOutput] = []
        timeStepOutputs.reserveCapacity(timeStepCount)
        var backpropagators: [Cell.Backpropagator] = []
        backpropagators.reserveCapacity(timeStepCount)
        for timestep in inputs {
            let (output, backpropagator) =
                cell.appliedForBackpropagation(to: .init(input: timestep,
                                                         state: currentHiddenState))
            currentHiddenState = output.state
            timeStepOutputs.append(output.output)
            backpropagators.append(backpropagator)
        }
        return (timeStepOutputs, { 𝛁outputs in
            precondition(𝛁outputs.base.count == timeStepCount,
                         "The number of output gradients must equal the number of time steps")
            var 𝛁cell = Cell.CotangentVector.zero
            var 𝛁state = Cell.State.CotangentVector.zero
            var reversed𝛁inputs: [Cell.TimeStepInput.CotangentVector] = []
            reversed𝛁inputs.reserveCapacity(timeStepCount)
            // Walk time steps backwards: each backpropagator receives the
            // output gradient for its step plus the state gradient flowing
            // back from step t+1.
            for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() {
                let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state))
                // Accumulate: the cell's parameters are shared across all time
                // steps, so their gradients must sum. The previous code
                // overwrote `𝛁cell`, keeping only the first step's gradient.
                𝛁cell += new𝛁cell
                𝛁state = 𝛁input.state
                reversed𝛁inputs.append(𝛁input.input)
            }
            return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())))
        })
    }

    /// Applies the cell over `inputs`, starting from the cell's zero state.
    @differentiable(wrt: (self, inputs))
    public func call(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
        return self(inputs, initialState: cell.zeroState.withoutDerivative())
    }

    /* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
    @differentiable(wrt: (self, inputs))
    public func lastOutput(from inputs: [Cell.TimeStepInput],
                           initialState: Cell.State) -> Cell.TimeStepOutput {
        precondition(!inputs.isEmpty, "inputs cannot be empty")
        return self(inputs, initialState: initialState).last!
    }

    @differentiable(wrt: (self, inputs))
    public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
        precondition(!inputs.isEmpty, "inputs cannot be empty")
        return self(inputs, initialState: cell.zeroState).last!
    }
    */
}
1487
// Conditional conformances: an RNN inherits these capabilities from its cell.
extension RNN: Equatable where Cell: Equatable {}
extension RNN: AdditiveArithmetic where Cell: AdditiveArithmetic {}
extension RNN: VectorNumeric where Cell: VectorNumeric {}

/// A simple RNN over `SimpleRNNCell`.
public typealias SimpleRNN<Scalar: TensorFlowFloatingPoint> = RNN<SimpleRNNCell<Scalar>>

/// An LSTM network over `LSTMCell`.
public typealias LSTM<Scalar: TensorFlowFloatingPoint> = RNN<LSTMCell<Scalar>>
0 commit comments