Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 25c7cfe

Browse files
eaplatanios and marcrasi
authored and committed
Fixed a couple RNN bugs. (#522)
* Marked initializer VJP parameters as '__owned'. * Fixed a couple RNN bugs. * Reverted refactoring.
1 parent 5eedf8c commit 25c7cfe

File tree

1 file changed

+36
-35
lines changed

1 file changed

+36
-35
lines changed

Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ public protocol RNNCell: Layer
6060
associatedtype TimeStepOutput: Differentiable
6161
/// The state that may be preserved across time steps.
6262
associatedtype State: Differentiable
63-
/// The zero state.
64-
var zeroState: State { get }
63+
64+
/// Returns a zero-valued state with shape compatible with the provided input.
65+
func zeroState(for input: TimeStepInput) -> State
6566
}
6667

6768
public extension RNNCell {
@@ -91,14 +92,6 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
9192
public var weight: Tensor<Scalar>
9293
public var bias: Tensor<Scalar>
9394

94-
@noDerivative public var stateShape: TensorShape {
95-
TensorShape([1, weight.shape[1]])
96-
}
97-
98-
public var zeroState: State {
99-
State(Tensor(zeros: stateShape))
100-
}
101-
10295
// TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after SR-10697 is fixed.
10396
public struct State: Equatable, Differentiable, VectorProtocol, KeyPathIterable {
10497
public var value: Tensor<Scalar>
@@ -124,6 +117,11 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
124117
self.bias = Tensor(zeros: [hiddenSize])
125118
}
126119

120+
/// Returns a zero-valued state with shape compatible with the provided input.
121+
public func zeroState(for input: Tensor<Scalar>) -> State {
122+
State(Tensor(zeros: [input.shape[0], weight.shape[1]]))
123+
}
124+
127125
/// Returns the output obtained from applying the layer to the given input.
128126
///
129127
/// - Parameter input: The input to the layer.
@@ -189,14 +187,6 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
189187
return fusedBias.slice(lowerBounds: [3 * hiddenSize], upperBounds: [4 * hiddenSize])
190188
}
191189

192-
@noDerivative public var stateShape: TensorShape {
193-
TensorShape([1, fusedWeight.shape[1] / 4])
194-
}
195-
196-
public var zeroState: State {
197-
State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
198-
}
199-
200190
public typealias TimeStepInput = Tensor<Scalar>
201191
public typealias TimeStepOutput = State
202192
public typealias Input = RNNCellInput<TimeStepInput, State>
@@ -223,6 +213,14 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
223213
}
224214
}
225215

216+
/// Returns a zero-valued state with shape compatible with the provided input.
217+
public func zeroState(for input: Tensor<Scalar>) -> State {
218+
let hiddenSize = fusedWeight.shape[1] / 4
219+
return State(
220+
cell: Tensor(zeros: [input.shape[0], hiddenSize]),
221+
hidden: Tensor(zeros: [input.shape[0], hiddenSize]))
222+
}
223+
226224
/// Returns the output obtained from applying the layer to the given input.
227225
///
228226
/// - Parameter input: The input to the layer.
@@ -272,27 +270,28 @@ public struct RNN<Cell: RNNCell>: Layer {
272270
self.cell = cell()
273271
}
274272

275-
@differentiable(wrt: (self, input), vjp: _vjpCallAsFunction(_:initialState:))
273+
@differentiable(wrt: (self, inputs), vjp: _vjpCallAsFunction(_:initialState:))
276274
public func callAsFunction(
277-
_ input: [Cell.TimeStepInput],
275+
_ inputs: [Cell.TimeStepInput],
278276
initialState: Cell.State
279277
) -> [Cell.TimeStepOutput] {
278+
if inputs.isEmpty { return [Cell.TimeStepOutput]() }
280279
var currentHiddenState = initialState
281280
var timeStepOutputs: [Cell.TimeStepOutput] = []
282-
for timestep in input {
283-
let output = cell(input: timestep, state: currentHiddenState)
281+
for timeStepInput in inputs {
282+
let output = cell(input: timeStepInput, state: currentHiddenState)
284283
currentHiddenState = output.state
285284
timeStepOutputs.append(output.output)
286285
}
287286
return timeStepOutputs
288287
}
289288

290-
@differentiable(wrt: (self, input))
289+
@differentiable(wrt: (self, inputs))
291290
public func call(
292-
_ input: [Cell.TimeStepInput],
291+
_ inputs: [Cell.TimeStepInput],
293292
initialState: Cell.State
294293
) -> [Cell.TimeStepOutput] {
295-
callAsFunction(input, initialState: initialState)
294+
callAsFunction(inputs, initialState: initialState)
296295
}
297296

298297
@usableFromInline
@@ -303,7 +302,7 @@ public struct RNN<Cell: RNNCell>: Layer {
303302
(Array<Cell.TimeStepOutput>.TangentVector)
304303
-> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)) {
305304
let timeStepCount = inputs.count
306-
var currentHiddenState = cell.zeroState
305+
var currentHiddenState = cell.zeroState(for: inputs[0])
307306
var timeStepOutputs: [Cell.TimeStepOutput] = []
308307
timeStepOutputs.reserveCapacity(timeStepCount)
309308
var backpropagators: [Cell.Backpropagator] = []
@@ -334,23 +333,25 @@ public struct RNN<Cell: RNNCell>: Layer {
334333

335334
@differentiable
336335
public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
337-
return self(inputs, initialState: withoutDerivative(at: cell.zeroState))
336+
let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
337+
return self(inputs, initialState: withoutDerivative(at: initialState))
338338
}
339339

340-
/* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
341340
@differentiable(wrt: (self, inputs))
342-
public func lastOutput(from inputs: [Cell.TimeStepInput],
343-
initialState: Cell.State) -> Cell.TimeStepOutput {
344-
precondition(!inputs.isEmpty, "inputs cannot be empty")
345-
return self(inputs, initialState: initialState).last!
341+
public func lastOutput(
342+
from inputs: [Cell.TimeStepInput],
343+
initialState: Cell.State
344+
) -> Cell.TimeStepOutput {
345+
precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
346+
return self(inputs, initialState: initialState)[withoutDerivative(at: inputs.count - 1)]
346347
}
347348

348349
@differentiable(wrt: (self, inputs))
349350
public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
350-
precondition(!inputs.isEmpty, "inputs cannot be empty")
351-
return self(inputs, initialState: cell.zeroState).last!
351+
precondition(!inputs.isEmpty, "'inputs' must be non-empty.")
352+
let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
353+
return lastOutput(from: inputs, initialState: initialState)
352354
}
353-
*/
354355
}
355356

356357
extension RNN: Equatable where Cell: Equatable {}

0 commit comments

Comments
 (0)