Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 5883af6

Browse files
authored
Workaround for SR-10697. (#123)
Add `State` wrapper struct to `SimpleRNNCell` to work around IRGen crash. TF-507 tracks reverting the change after SR-10697 is fixed.
1 parent ffba693 commit 5883af6

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

Sources/DeepLearning/Layer.swift

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1299,11 +1299,19 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
12991299
return TensorShape([1, weight.shape[1]])
13001300
}
13011301

1302-
public var zeroState: Tensor<Scalar> {
1303-
return Tensor(zeros: stateShape)
1302+
public var zeroState: State {
1303+
return State(Tensor(zeros: stateShape))
1304+
}
1305+
1306+
// TODO(TF-507): Revert to `typealias State = Tensor<Scalar>` after
1307+
// SR-10697 is fixed.
1308+
public struct State: Differentiable {
1309+
public let value: Tensor<Scalar>
1310+
public init(_ value: Tensor<Scalar>) {
1311+
self.value = value
1312+
}
13041313
}
13051314

1306-
public typealias State = Tensor<Scalar>
13071315
public typealias TimeStepInput = Tensor<Scalar>
13081316
public typealias TimeStepOutput = State
13091317
public typealias Input = RNNCellInput<TimeStepInput, State>
@@ -1319,8 +1327,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
13191327
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
13201328
Int64.random(in: Int64.min..<Int64.max))) {
13211329
let concatenatedInputSize = inputSize + hiddenSize
1322-
self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize],
1323-
seed: seed)
1330+
self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize], seed: seed)
13241331
self.bias = Tensor(zeros: [hiddenSize])
13251332
}
13261333

@@ -1330,8 +1337,8 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
13301337
/// - Returns: The hidden state.
13311338
@differentiable
13321339
public func call(_ input: Input) -> Output {
1333-
let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
1334-
let newState = tanh(matmul(concatenatedInput, weight) + bias)
1340+
let concatenatedInput = input.input.concatenated(with: input.state.value, alongAxis: 1)
1341+
let newState = State(tanh(matmul(concatenatedInput, weight) + bias))
13351342
return Output(output: newState, state: newState)
13361343
}
13371344
}

0 commit comments

Comments
 (0)