@@ -60,8 +60,9 @@ public protocol RNNCell: Layer
60
60
associatedtype TimeStepOutput : Differentiable
61
61
/// The state that may be preserved across time steps.
62
62
associatedtype State : Differentiable
63
- /// The zero state.
64
- var zeroState : State { get }
63
+
64
+ /// Returns a zero-valued state with shape compatible with the provided input.
65
+ func zeroState( for input: TimeStepInput ) -> State
65
66
}
66
67
67
68
public extension RNNCell {
@@ -91,14 +92,6 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
91
92
public var weight : Tensor < Scalar >
92
93
public var bias : Tensor < Scalar >
93
94
94
- @noDerivative public var stateShape : TensorShape {
95
- TensorShape ( [ 1 , weight. shape [ 1 ] ] )
96
- }
97
-
98
- public var zeroState : State {
99
- State ( Tensor ( zeros: stateShape) )
100
- }
101
-
102
95
// TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after SR-10697 is fixed.
103
96
public struct State : Equatable , Differentiable , VectorProtocol , KeyPathIterable {
104
97
public var value : Tensor < Scalar >
@@ -124,6 +117,11 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
124
117
self . bias = Tensor ( zeros: [ hiddenSize] )
125
118
}
126
119
120
/// Returns a zero-valued state with shape compatible with the provided input.
///
/// The result has one row per entry along `input`'s leading dimension and one
/// column per hidden unit (the second dimension of `weight`).
public func zeroState(for input: Tensor<Scalar>) -> State {
  let batchSize = input.shape[0]
  let hiddenSize = weight.shape[1]
  return State(Tensor(zeros: [batchSize, hiddenSize]))
}
124
+
127
125
/// Returns the output obtained from applying the layer to the given input.
128
126
///
129
127
/// - Parameter input: The input to the layer.
@@ -189,14 +187,6 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
189
187
return fusedBias. slice ( lowerBounds: [ 3 * hiddenSize] , upperBounds: [ 4 * hiddenSize] )
190
188
}
191
189
192
- @noDerivative public var stateShape : TensorShape {
193
- TensorShape ( [ 1 , fusedWeight. shape [ 1 ] / 4 ] )
194
- }
195
-
196
- public var zeroState : State {
197
- State ( cell: Tensor ( zeros: stateShape) , hidden: Tensor ( zeros: stateShape) )
198
- }
199
-
200
190
public typealias TimeStepInput = Tensor < Scalar >
201
191
public typealias TimeStepOutput = State
202
192
public typealias Input = RNNCellInput < TimeStepInput , State >
@@ -223,6 +213,14 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
223
213
}
224
214
}
225
215
216
/// Returns a zero-valued state with shape compatible with the provided input.
public func zeroState(for input: Tensor<Scalar>) -> State {
  // The fused weight packs the four LSTM gates side by side, so the hidden
  // size is a quarter of its second dimension.
  let batchSize = input.shape[0]
  let hiddenSize = fusedWeight.shape[1] / 4
  let zeros = Tensor<Scalar>(zeros: [batchSize, hiddenSize])
  return State(cell: zeros, hidden: zeros)
}
223
+
226
224
/// Returns the output obtained from applying the layer to the given input.
227
225
///
228
226
/// - Parameter input: The input to the layer.
@@ -272,27 +270,28 @@ public struct RNN<Cell: RNNCell>: Layer {
272
270
self . cell = cell ( )
273
271
}
274
272
275
/// Returns the outputs obtained from applying the cell to every time step of
/// `inputs`, threading the cell state from one step to the next.
///
/// - Parameters:
///   - inputs: The time-step inputs, ordered by time.
///   - initialState: The cell state before the first time step.
/// - Returns: One time-step output per element of `inputs`; empty when
///   `inputs` is empty.
@differentiable(wrt: (self, inputs), vjp: _vjpCallAsFunction(_:initialState:))
public func callAsFunction(
  _ inputs: [Cell.TimeStepInput],
  initialState: Cell.State
) -> [Cell.TimeStepOutput] {
  if inputs.isEmpty { return [Cell.TimeStepOutput]() }
  var currentHiddenState = initialState
  var timeStepOutputs: [Cell.TimeStepOutput] = []
  // The element count is known up front; reserving avoids repeated
  // reallocation and matches `_vjpCallAsFunction`, which already reserves.
  timeStepOutputs.reserveCapacity(inputs.count)
  for timeStepInput in inputs {
    let output = cell(input: timeStepInput, state: currentHiddenState)
    currentHiddenState = output.state
    timeStepOutputs.append(output.output)
  }
  return timeStepOutputs
}
289
288
290
/// Returns the outputs obtained from applying the layer to the given inputs.
///
/// Forwards directly to `callAsFunction(_:initialState:)`.
@differentiable(wrt: (self, inputs))
public func call(
  _ inputs: [Cell.TimeStepInput],
  initialState: Cell.State
) -> [Cell.TimeStepOutput] {
  return callAsFunction(inputs, initialState: initialState)
}
297
296
298
297
@usableFromInline
@@ -303,7 +302,7 @@ public struct RNN<Cell: RNNCell>: Layer {
303
302
( Array < Cell . TimeStepOutput > . TangentVector )
304
303
-> ( TangentVector , Array < Cell . TimeStepInput > . TangentVector ) ) {
305
304
let timeStepCount = inputs. count
306
- var currentHiddenState = cell. zeroState
305
+ var currentHiddenState = cell. zeroState ( for : inputs [ 0 ] )
307
306
var timeStepOutputs : [ Cell . TimeStepOutput ] = [ ]
308
307
timeStepOutputs. reserveCapacity ( timeStepCount)
309
308
var backpropagators : [ Cell . Backpropagator ] = [ ]
@@ -334,23 +333,25 @@ public struct RNN<Cell: RNNCell>: Layer {
334
333
335
334
/// Returns the outputs obtained from applying the cell to every time step of
/// `inputs`, starting from a zero-valued state sized for the first input.
///
/// - Parameter inputs: The time-step inputs, ordered by time. Must be
///   non-empty, since the first input determines the zero state's shape.
@differentiable
public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
  // Fail with a clear message (consistent with `lastOutput`) instead of
  // trapping cryptically on `inputs[0]` below.
  precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
  // `initialState` is already detached from differentiation here; wrapping it
  // in `withoutDerivative(at:)` a second time at the call site was redundant.
  let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
  return self(inputs, initialState: initialState)
}
339
339
340
- /* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
341
340
/// Returns the final time-step output from applying the cell across `inputs`,
/// beginning from `initialState`.
///
/// - Precondition: `inputs` must be non-empty.
@differentiable(wrt: (self, inputs))
public func lastOutput(
  from inputs: [Cell.TimeStepInput],
  initialState: Cell.State
) -> Cell.TimeStepOutput {
  precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
  let allOutputs = self(inputs, initialState: initialState)
  // The subscript index must not participate in differentiation.
  let lastIndex = withoutDerivative(at: inputs.count - 1)
  return allOutputs[lastIndex]
}
347
348
348
349
/// Returns the final time-step output, starting from a zero-valued state
/// sized for the first input.
///
/// - Precondition: `inputs` must be non-empty.
@differentiable(wrt: (self, inputs))
public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
  precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
  let zeroState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
  return lastOutput(from: inputs, initialState: zeroState)
}
353
- */
354
355
}
355
356
356
357
/// Two `RNN`s compare equal exactly when their cells do.
extension RNN: Equatable where Cell: Equatable { }
0 commit comments