@@ -1299,11 +1299,19 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
1299
1299
return TensorShape ( [ 1 , weight. shape [ 1 ] ] )
1300
1300
}
1301
1301
1302
- public var zeroState : Tensor < Scalar > {
1303
- return Tensor ( zeros: stateShape)
1302
+ public var zeroState : State {
1303
+ return State ( Tensor ( zeros: stateShape) )
1304
+ }
1305
+
1306
+ // TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
1307
+ // SR-10697 is fixed.
1308
+ public struct State : Differentiable {
1309
+ public let value : Tensor < Scalar >
1310
+ public init ( _ value: Tensor < Scalar > ) {
1311
+ self . value = value
1312
+ }
1304
1313
}
1305
1314
1306
- public typealias State = Tensor < Scalar >
1307
1315
public typealias TimeStepInput = Tensor < Scalar >
1308
1316
public typealias TimeStepOutput = State
1309
1317
public typealias Input = RNNCellInput < TimeStepInput , State >
@@ -1319,8 +1327,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
1319
1327
seed: ( Int64 , Int64 ) = ( Int64 . random ( in: Int64 . min..< Int64 . max) ,
1320
1328
Int64 . random ( in: Int64 . min..< Int64 . max) ) ) {
1321
1329
let concatenatedInputSize = inputSize + hiddenSize
1322
- self . weight = Tensor ( glorotUniform: [ concatenatedInputSize, hiddenSize] ,
1323
- seed: seed)
1330
+ self . weight = Tensor ( glorotUniform: [ concatenatedInputSize, hiddenSize] , seed: seed)
1324
1331
self . bias = Tensor ( zeros: [ hiddenSize] )
1325
1332
}
1326
1333
@@ -1330,8 +1337,8 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
1330
1337
/// - Returns: The hidden state.
1331
1338
@differentiable
1332
1339
public func call( _ input: Input ) -> Output {
1333
- let concatenatedInput = input. input. concatenated ( with: input. state, alongAxis: 1 )
1334
- let newState = tanh ( matmul ( concatenatedInput, weight) + bias)
1340
+ let concatenatedInput = input. input. concatenated ( with: input. state. value , alongAxis: 1 )
1341
+ let newState = State ( tanh ( matmul ( concatenatedInput, weight) + bias) )
1335
1342
return Output ( output: newState, state: newState)
1336
1343
}
1337
1344
}
0 commit comments