Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 86b06f6

Browse files
committed
Address review comments.
1 parent 2292e6b commit 86b06f6

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,7 +1306,10 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
13061306
// TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
13071307
// SR-10697 is fixed.
13081308
public struct State: Differentiable {
1309-
let state: Tensor<Scalar>
1309+
public let value: Tensor<Scalar>
1310+
public init(_ value: Tensor<Scalar>) {
1311+
self.value = value
1312+
}
13101313
}
13111314

13121315
public typealias TimeStepInput = Tensor<Scalar>
@@ -1334,8 +1337,8 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
13341337
/// - Returns: The hidden state.
13351338
@differentiable
13361339
public func call(_ input: Input) -> Output {
1337-
let concatenatedInput = input.input.concatenated(with: input.state.state, alongAxis: 1)
1338-
let newState = State(state: tanh(matmul(concatenatedInput, weight) + bias))
1340+
let concatenatedInput = input.input.concatenated(with: input.state.value, alongAxis: 1)
1341+
let newState = State(tanh(matmul(concatenatedInput, weight) + bias))
13391342
return Output(output: newState, state: newState)
13401343
}
13411344
}

0 commit comments

Comments
 (0)