
Commit 94ab4f5

Fix derivative of `RNN.callAsFunction(_:initialState:)`. (#660)

Previously, `RNN._vjpCallAsFunction(_:initialState:)` incorrectly used a zero initial state. Now it uses `initialState` as the initial state. Add `RNN` gradient tests for `SimpleRNNCell`, `LSTMCell`, and `GRUCell`. TODO: verify that gradients are correct using a reference implementation.
1 parent 30296a9 commit 94ab4f5
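
Why the zero state mattered: a pullback is only valid for the exact forward trajectory it was recorded from, so recording the forward pass from a zero state while the primal ran from the caller's `initialState` yields gradients for the wrong computation. A minimal plain-Swift illustration with a hypothetical scalar cell h' = tanh(w·x + u·h + b); the cell, names, and manual backprop below are illustrative only, not the swift-apis implementation:

import Foundation

// Gradient of (sum of outputs) w.r.t. the recurrent weight `u` for a
// hypothetical scalar RNN cell h' = tanh(w*x + u*h + b), computed by
// recording the forward trajectory and backpropagating through it.
func gradWrtU(inputs: [Double], initialState h0: Double,
              w: Double, u: Double, b: Double) -> Double {
  // Forward pass: record the state trajectory starting from h0.
  var states = [h0]
  for x in inputs {
    states.append(tanh(w * x + u * states.last! + b))
  }
  // Reverse pass over the recorded trajectory.
  var grad = 0.0
  var dh = 0.0  // gradient flowing into h_t from later time steps
  for t in stride(from: inputs.count - 1, through: 0, by: -1) {
    let local = 1 - states[t + 1] * states[t + 1]  // tanh'(pre_t)
    let dPre = (1.0 + dh) * local  // 1.0 is this step's direct loss term
    grad += dPre * states[t]       // d(pre_t)/du = h_t
    dh = dPre * u                  // gradient sent to the previous state
  }
  return grad
}

let xs = [0.0, 0.1, 0.2, 0.3]
// Correct: backprop through the trajectory starting at the caller's state.
let correct = gradWrtU(inputs: xs, initialState: 0.5, w: 0.4, u: 0.3, b: 0.1)
// Pre-fix behavior: the recorded trajectory starts at zero instead.
let buggy = gradWrtU(inputs: xs, initialState: 0.0, w: 0.4, u: 0.3, b: 0.1)
print(correct, buggy)  // the two gradients differ whenever initialState != 0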

File tree

2 files changed: 79 additions, 23 deletions


Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 2 additions & 2 deletions
@@ -379,7 +379,7 @@ public struct RNN<Cell: RNNCell>: Layer {
       -> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)
   ) {
     let timeStepCount = inputs.count
-    var currentHiddenState = cell.zeroState(for: inputs[0])
+    var currentHiddenState = initialState
     var timeStepOutputs: [Cell.TimeStepOutput] = []
     timeStepOutputs.reserveCapacity(timeStepCount)
     var backpropagators: [Cell.Backpropagator] = []
@@ -411,7 +411,7 @@ public struct RNN<Cell: RNNCell>: Layer {
   @differentiable
   public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
     let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
-    return self(inputs, initialState: withoutDerivative(at: initialState))
+    return self(inputs, initialState: initialState)
   }

   @differentiable(wrt: (self, inputs))
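
The context lines above show the shape of the fixed VJP: run the forward loop from `initialState` while collecting one backpropagator per time step, then apply the backpropagators in reverse to thread state gradients backward. A self-contained plain-Swift sketch of that scan-with-pullbacks pattern, with scalar `Double` state and illustrative names (`scanWithPullback`, `Backprop`) rather than the swift-apis `RNNCell`/`Backpropagator` machinery:

// One backpropagator per step:
// (d output, d next state) -> (d input, d previous state).
typealias Backprop = (_ dOut: Double, _ dState: Double) -> (dIn: Double, dPrev: Double)

func scanWithPullback(
  _ inputs: [Double], initialState: Double,
  step: (Double, Double) -> (out: Double, state: Double, pullback: Backprop)
) -> (outputs: [Double], pullback: ([Double]) -> (dInputs: [Double], dInitialState: Double)) {
  var state = initialState  // the #660 fix: start from the caller's state, not zero
  var outputs: [Double] = []
  var backprops: [Backprop] = []
  for x in inputs {  // forward scan, recording a backpropagator per step
    let (y, s, pb) = step(x, state)
    outputs.append(y)
    state = s
    backprops.append(pb)
  }
  return (outputs, { dOutputs in
    // Reverse scan: compose the recorded backpropagators, last step first.
    var dInputs = [Double](repeating: 0, count: dOutputs.count)
    var dState = 0.0
    for t in stride(from: dOutputs.count - 1, through: 0, by: -1) {
      let (dIn, dPrev) = backprops[t](dOutputs[t], dState)
      dInputs[t] = dIn
      dState = dPrev
    }
    return (dInputs, dState)
  })
}

// Toy usage with a linear "cell" y = s' = x + h:
let (ys, pb) = scanWithPullback([1, 2, 3], initialState: 10) { x, h in
  (out: x + h, state: x + h, pullback: { dOut, dState in (dOut + dState, dOut + dState) })
}
print(ys)             // [11.0, 13.0, 16.0]
print(pb([1, 1, 1]))  // ([3.0, 2.0, 1.0], 3.0)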

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 77 additions & 21 deletions
@@ -1139,10 +1139,9 @@ final class LayerTests: XCTestCase {
   func testRNN() {
     let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
     let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
-    let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
-                                       seed: (0xFeed, 0xBeef)))
+    let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (0xFeed, 0xBeef)))
     withTensorLeakChecking {
-      let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
+      let (outputs, pullback) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
         return rnn(inputs)
       }
       assertEqual(
@@ -1152,29 +1151,29 @@ final class LayerTests: XCTestCase {
         [ 0.23758979, 0.32101023, -0.20359215, -0.1787096],
         [ 0.24337786, 0.3389194, -0.21143384, -0.1675081]],
         accuracy: 1e-6)
+      let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
+      // TODO: Verify that RNN gradients are correct using a reference implementation.
+      XCTAssertEqual(𝛁rnn.cell.weight,
+        [[ 0.0, 0.0, 0.0, 0.0],
+         [-0.014372801, 0.03128201, 0.07844338, 0.08569162],
+         [-0.028745603, 0.06256402, 0.15688676, 0.17138325],
+         [-0.043118402, 0.09384604, 0.2353301, 0.25707486],
+         [-0.019920545, 0.05355064, 0.13140751, 0.15169607],
+         [-0.024906494, 0.06562942, 0.15947133, 0.18506715],
+         [ 0.016476292, -0.042923313, -0.10459379, -0.12082438],
+         [ 0.013913135, -0.040882945, -0.100636974, -0.11757788]])
+      XCTAssertEqual(𝛁rnn.cell.bias, [-0.14372802, 0.31282014, 0.78443366, 0.8569162])
     }
-    // TODO: Figure out why the following is numerically unstable.
-    // let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
-    // XCTAssertEqual(𝛁rnn.cell.weight,
-    //                [[ 0.0, 0.0, 0.0, 0.0],
-    //                 [ 0.02496884, 0.06694733, 0.07978788, -0.022378458],
-    //                 [ 0.04993768, 0.13389467, 0.15957576, -0.044756915],
-    //                 [ 0.07490652, 0.20084201, 0.23936366, -0.06713537],
-    //                 [ 0.0, 0.0, 0.0, 0.0],
-    //                 [ 0.0, 0.0, 0.0, 0.0],
-    //                 [ 0.0, 0.0, 0.0, 0.0],
-    //                 [ 0.0, 0.0, 0.0, 0.0]])
-    // XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
   }

   func testLSTM() {
     withRandomSeedForTensorFlow((0xFeed, 0xBeef)) {
       let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
       let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
-      let rnn = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4))
+      let lstm = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4))
       withTensorLeakChecking {
-        let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
-          return rnn(inputs)
+        let (outputs, pullback) = valueWithPullback(at: lstm, inputs) { lstm, inputs in
+          return lstm(inputs)
         }
         assertEqual(
           outputs.map { $0.cell.squeezingShape(at: 0) }[0],
@@ -1190,22 +1189,62 @@ final class LayerTests: XCTestCase {
           [ 0.074910110, 0.021107012, -0.049724963, -0.069670826],
           [ 0.078670055, 0.022462710, -0.051899005, -0.075331904]],
           accuracy: 1e-6)
+        let (𝛁lstm, _) = pullback(.init(inputs.map { LSTMCell<Float>.State(cell: $0, hidden: $0) }))
+        // TODO: Verify that LSTM gradients are correct using a reference implementation.
+        XCTAssertEqual(𝛁lstm.cell.fusedWeight,
+          [[ 0.0, 0.0, 0.0, 0.0,
+             0.0, 0.0, 0.0, 0.0,
+             0.0, 0.0, 0.0, 0.0,
+             0.0, 0.0, 0.0, 0.0],
+           [ 0.00012854872, 0.0013978262, -0.0064465487, -0.011084668,
+             0.001252454, 0.04924231, 0.1023805, 0.12028344,
+             3.0466243e-05, 0.0006108698, -0.0027553777, -0.0048254076,
+             0.00011663328, 0.0006076429, -0.0026212593, -0.003298801],
+           [ 0.00025709745, 0.0027956525, -0.0128930975, -0.022169337,
+             0.002504908, 0.09848462, 0.204761, 0.24056688,
+             6.0932485e-05, 0.0012217396, -0.0055107553, -0.009650815,
+             0.00023326656, 0.0012152859, -0.0052425186, -0.006597602],
+           [ 0.00038564618, 0.0041934787, -0.019339647, -0.03325401,
+             0.003757362, 0.14772694, 0.3071415, 0.36085027,
+             9.1398724e-05, 0.0018326094, -0.008266133, -0.014476223,
+             0.00034989987, 0.0018229289, -0.007863778, -0.009896403],
+           [ 2.7438582e-05, 0.00056287006, -0.0024641054, -0.004909771,
+             0.00028730888, 0.019899525, 0.0410647, 0.050809838,
+             1.6388643e-05, 0.0003807871, -0.0017060185, -0.0030680457,
+             3.7163307e-05, 0.00029245956, -0.0012287574, -0.0018296391],
+           [ 7.462907e-06, 0.00015513944, -0.00067863404, -0.0013554879,
+             7.8164114e-05, 0.0054846643, 0.011315366, 0.014021275,
+             4.4683666e-06, 0.000105314364, -0.0004715426, -0.0008500715,
+             1.0132099e-05, 8.078475e-05, -0.00033919656, -0.00050792634],
+           [ -1.818974e-05, -0.0003736046, 0.0016354292, 0.0032592756,
+             -0.00019047626, -0.013208302, -0.027256217, -0.033727698,
+             -1.0870848e-05, -0.00025284386, 0.0011327572, 0.002037401,
+             -2.465073e-05, -0.00019415902, 0.000815714, 0.0012148761],
+           [-2.3125162e-05, -0.0004929221, 0.0021531105, 0.00431989,
+             -0.00024233271, -0.017425863, -0.035934873, -0.0446482,
+             -1.3914708e-05, -0.00033675073, 0.0015061073, 0.0027270648,
+             -3.15488e-05, -0.0002577127, 0.001080812, 0.0016348549]])
+        XCTAssertEqual(𝛁lstm.cell.fusedBias,
+          [0.0012854873, 0.013978262, -0.06446548, -0.11084669,
+           0.01252454, 0.49242306, 1.023805, 1.2028344,
+           0.0003046624, 0.0061086984, -0.027553776, -0.048254073,
+           0.0011663327, 0.006076429, -0.02621259, -0.032988008])
       }
     }
   }

   func testGRU() {
     let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
     let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
-    let rnn = RNN(GRUCell<Float>(
+    let gru = RNN(GRUCell<Float>(
       inputSize: 4,
       hiddenSize: 4,
       weightInitializer: glorotUniform(seed: (0xFeed, 0xBeef)),
       biasInitializer: zeros())
     )
     withTensorLeakChecking {
-      let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
-        return rnn(inputs)
+      let (outputs, pullback) = valueWithPullback(at: gru, inputs) { gru, inputs in
+        return gru(inputs)
       }
       assertEqual(
         outputs.map { $0.hidden }[0],
@@ -1214,6 +1253,23 @@ final class LayerTests: XCTestCase {
         [0.2230835, 0.2230835, 0.2230835, 0.2230835],
         [0.2383619, 0.2383619, 0.2383619, 0.2383619]],
         accuracy: 1e-5)
+      // TODO: Verify that GRU gradients are correct using a reference implementation.
+      let (𝛁gru, _) = pullback(.init(inputs.map { GRUCell<Float>.State(hidden: $0) }))
+      XCTAssertEqual(𝛁gru.cell.updateWeight1,
+        [[0.0], [-0.040293925], [-0.08058785], [-0.12088178]])
+      XCTAssertEqual(𝛁gru.cell.updateWeight2,
+        [[-0.056792725], [-0.056792725], [-0.056792725], [-0.056792725]])
+      XCTAssertEqual(𝛁gru.cell.resetWeight1,
+        [[0.0], [0.0039126356], [0.007825271], [0.011737906]])
+      XCTAssertEqual(𝛁gru.cell.resetWeight2,
+        [[0.0069182813], [0.0069182813], [0.0069182813], [0.0069182813]])
+      XCTAssertEqual(𝛁gru.cell.outputWeight1,
+        [[0.0], [0.1221647], [0.2443294], [0.3664941]])
+      XCTAssertEqual(𝛁gru.cell.outputWeight2,
+        [[0.08078343], [0.08078343], [0.08078343], [0.08078343]])
+      XCTAssertEqual(𝛁gru.cell.updateBias, [-0.016739635, -0.04493352, -0.13216142, -0.20910467])
+      XCTAssertEqual(𝛁gru.cell.resetBias, [0.023218961, -0.024303729, 0.010057628, 0.030153492])
+      XCTAssertEqual(𝛁gru.cell.outputBias, [0.06667276, 0.115095116, 0.39864573, 0.6412333])
     }
   }
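Until the reference implementation called for in the TODOs above is wired in, a central finite-difference probe is a cheap cross-check for hardcoded expected gradients like these. A minimal plain-Swift sketch; `numericGradient` is a hypothetical helper, not part of LayerTests:

import Foundation

// Central differences: grad_i ≈ (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps).
func numericGradient(of f: ([Double]) -> Double,
                     at x: [Double], eps: Double = 1e-5) -> [Double] {
  var grad = [Double](repeating: 0, count: x.count)
  for i in x.indices {
    var plus = x, minus = x
    plus[i] += eps
    minus[i] -= eps
    grad[i] = (f(plus) - f(minus)) / (2 * eps)
  }
  return grad
}

// Example: probe d/dw and d/dh0 of a one-step cell h' = tanh(w*x + h0).
let x0 = 0.3
let f: ([Double]) -> Double = { p in tanh(p[0] * x0 + p[1]) }
print(numericGradient(of: f, at: [0.4, 0.5]))
// ≈ [x0 * (1 - tanh²(pre)), 1 - tanh²(pre)]; compare against the analytic pullback.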