Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Fix derivative of RNN.callAsFunction(_:initialState:). #660

Merged
merged 1 commit into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Sources/TensorFlow/Layers/Recurrent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ public struct RNN<Cell: RNNCell>: Layer {
-> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)
) {
let timeStepCount = inputs.count
var currentHiddenState = cell.zeroState(for: inputs[0])
var currentHiddenState = initialState
var timeStepOutputs: [Cell.TimeStepOutput] = []
timeStepOutputs.reserveCapacity(timeStepCount)
var backpropagators: [Cell.Backpropagator] = []
Expand Down Expand Up @@ -411,7 +411,7 @@ public struct RNN<Cell: RNNCell>: Layer {
@differentiable
public func callAsFunction(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
let initialState = withoutDerivative(at: cell.zeroState(for: inputs[0]))
return self(inputs, initialState: withoutDerivative(at: initialState))
return self(inputs, initialState: initialState)
}

@differentiable(wrt: (self, inputs))
Expand Down
98 changes: 77 additions & 21 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1139,10 +1139,9 @@ final class LayerTests: XCTestCase {
func testRNN() {
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
seed: (0xFeed, 0xBeef)))
let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (0xFeed, 0xBeef)))
withTensorLeakChecking {
let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
let (outputs, pullback) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
return rnn(inputs)
}
assertEqual(
Expand All @@ -1152,29 +1151,29 @@ final class LayerTests: XCTestCase {
[ 0.23758979, 0.32101023, -0.20359215, -0.1787096],
[ 0.24337786, 0.3389194, -0.21143384, -0.1675081]],
accuracy: 1e-6)
let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
// TODO: Verify that RNN gradients are correct using a reference implementation.
XCTAssertEqual(𝛁rnn.cell.weight,
[[ 0.0, 0.0, 0.0, 0.0],
[-0.014372801, 0.03128201, 0.07844338, 0.08569162],
[-0.028745603, 0.06256402, 0.15688676, 0.17138325],
[-0.043118402, 0.09384604, 0.2353301, 0.25707486],
[-0.019920545, 0.05355064, 0.13140751, 0.15169607],
[-0.024906494, 0.06562942, 0.15947133, 0.18506715],
[ 0.016476292, -0.042923313, -0.10459379, -0.12082438],
[ 0.013913135, -0.040882945, -0.100636974, -0.11757788]])
XCTAssertEqual(𝛁rnn.cell.bias, [-0.14372802, 0.31282014, 0.78443366, 0.8569162])
}
// TODO: Figure out why the following is numerically unstable.
// let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
// XCTAssertEqual(𝛁rnn.cell.weight,
// [[ 0.0, 0.0, 0.0, 0.0],
// [ 0.02496884, 0.06694733, 0.07978788, -0.022378458],
// [ 0.04993768, 0.13389467, 0.15957576, -0.044756915],
// [ 0.07490652, 0.20084201, 0.23936366, -0.06713537],
// [ 0.0, 0.0, 0.0, 0.0],
// [ 0.0, 0.0, 0.0, 0.0],
// [ 0.0, 0.0, 0.0, 0.0],
// [ 0.0, 0.0, 0.0, 0.0]])
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
}

func testLSTM() {
withRandomSeedForTensorFlow((0xFeed, 0xBeef)) {
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
let rnn = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4))
let lstm = RNN(LSTMCell<Float>(inputSize: 4, hiddenSize: 4))
withTensorLeakChecking {
let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
return rnn(inputs)
let (outputs, pullback) = valueWithPullback(at: lstm, inputs) { lstm, inputs in
return lstm(inputs)
}
assertEqual(
outputs.map { $0.cell.squeezingShape(at: 0) }[0],
Expand All @@ -1190,22 +1189,62 @@ final class LayerTests: XCTestCase {
[ 0.074910110, 0.021107012, -0.049724963, -0.069670826],
[ 0.078670055, 0.022462710, -0.051899005, -0.075331904]],
accuracy: 1e-6)
let (𝛁lstm, _) = pullback(.init(inputs.map { LSTMCell<Float>.State(cell: $0, hidden: $0) }))
// TODO: Verify that LSTM gradients are correct using a reference implementation.
Copy link
Contributor

@marcrasi marcrasi Feb 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm interested in doing this. Should we do something like check some TF python code into this repository that reproduces the gradients?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good!

XCTAssertEqual(𝛁lstm.cell.fusedWeight,
[[ 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0],
[ 0.00012854872, 0.0013978262, -0.0064465487, -0.011084668,
0.001252454, 0.04924231, 0.1023805, 0.12028344,
3.0466243e-05, 0.0006108698, -0.0027553777, -0.0048254076,
0.00011663328, 0.0006076429, -0.0026212593, -0.003298801],
[ 0.00025709745, 0.0027956525, -0.0128930975, -0.022169337,
0.002504908, 0.09848462, 0.204761, 0.24056688,
6.0932485e-05, 0.0012217396, -0.0055107553, -0.009650815,
0.00023326656, 0.0012152859, -0.0052425186, -0.006597602],
[ 0.00038564618, 0.0041934787, -0.019339647, -0.03325401,
0.003757362, 0.14772694, 0.3071415, 0.36085027,
9.1398724e-05, 0.0018326094, -0.008266133, -0.014476223,
0.00034989987, 0.0018229289, -0.007863778, -0.009896403],
[ 2.7438582e-05, 0.00056287006, -0.0024641054, -0.004909771,
0.00028730888, 0.019899525, 0.0410647, 0.050809838,
1.6388643e-05, 0.0003807871, -0.0017060185, -0.0030680457,
3.7163307e-05, 0.00029245956, -0.0012287574, -0.0018296391],
[ 7.462907e-06, 0.00015513944, -0.00067863404, -0.0013554879,
7.8164114e-05, 0.0054846643, 0.011315366, 0.014021275,
4.4683666e-06, 0.000105314364, -0.0004715426, -0.0008500715,
1.0132099e-05, 8.078475e-05, -0.00033919656, -0.00050792634],
[ -1.818974e-05, -0.0003736046, 0.0016354292, 0.0032592756,
-0.00019047626, -0.013208302, -0.027256217, -0.033727698,
-1.0870848e-05, -0.00025284386, 0.0011327572, 0.002037401,
-2.465073e-05, -0.00019415902, 0.000815714, 0.0012148761],
[-2.3125162e-05, -0.0004929221, 0.0021531105, 0.00431989,
-0.00024233271, -0.017425863, -0.035934873, -0.0446482,
-1.3914708e-05, -0.00033675073, 0.0015061073, 0.0027270648,
-3.15488e-05, -0.0002577127, 0.001080812, 0.0016348549]])
XCTAssertEqual(𝛁lstm.cell.fusedBias,
[0.0012854873, 0.013978262, -0.06446548, -0.11084669,
0.01252454, 0.49242306, 1.023805, 1.2028344,
0.0003046624, 0.0061086984, -0.027553776, -0.048254073,
0.0011663327, 0.006076429, -0.02621259, -0.032988008])
}
}
}

func testGRU() {
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
let rnn = RNN(GRUCell<Float>(
let gru = RNN(GRUCell<Float>(
inputSize: 4,
hiddenSize: 4,
weightInitializer: glorotUniform(seed: (0xFeed, 0xBeef)),
biasInitializer: zeros())
)
withTensorLeakChecking {
let (outputs, _) = valueWithPullback(at: rnn, inputs) { rnn, inputs in
return rnn(inputs)
let (outputs, pullback) = valueWithPullback(at: gru, inputs) { gru, inputs in
return gru(inputs)
}
assertEqual(
outputs.map { $0.hidden }[0],
Expand All @@ -1214,6 +1253,23 @@ final class LayerTests: XCTestCase {
[0.2230835, 0.2230835, 0.2230835, 0.2230835],
[0.2383619, 0.2383619, 0.2383619, 0.2383619]],
accuracy: 1e-5)
// TODO: Verify that GRU gradients are correct using a reference implementation.
let (𝛁gru, _) = pullback(.init(inputs.map { GRUCell<Float>.State(hidden: $0) }))
XCTAssertEqual(𝛁gru.cell.updateWeight1,
[[ 0.0], [-0.040293925], [ -0.08058785], [ -0.12088178]])
XCTAssertEqual(𝛁gru.cell.updateWeight2,
[[-0.056792725], [-0.056792725], [-0.056792725], [-0.056792725]])
XCTAssertEqual(𝛁gru.cell.resetWeight1,
[[ 0.0], [0.0039126356], [ 0.007825271], [ 0.011737906]])
XCTAssertEqual(𝛁gru.cell.resetWeight2,
[[0.0069182813], [0.0069182813], [0.0069182813], [0.0069182813]])
XCTAssertEqual(𝛁gru.cell.outputWeight1,
[[ 0.0], [0.1221647], [0.2443294], [0.3664941]])
XCTAssertEqual(𝛁gru.cell.outputWeight2,
[[0.08078343], [0.08078343], [0.08078343], [0.08078343]])
XCTAssertEqual(𝛁gru.cell.updateBias, [-0.016739635, -0.04493352, -0.13216142, -0.20910467])
XCTAssertEqual(𝛁gru.cell.resetBias, [ 0.023218961, -0.024303729, 0.010057628, 0.030153492])
XCTAssertEqual(𝛁gru.cell.outputBias, [ 0.06667276, 0.115095116, 0.39864573, 0.6412333])
}
}

Expand Down