This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3f71684

Add RNN wrapper for Cells (#105)
* Add RNN wrapper for Cells
* Apply suggestions from code review
  Co-Authored-By: tanmayb123 <[email protected]>
* Implement RNN pullback and add tests.
* Minor improvements.
* Fix a typo.
* Assert -> precondition.
1 parent cfefd63 commit 3f71684

File tree: 2 files changed, +130 -10 lines changed


Sources/DeepLearning/Layer.swift

Lines changed: 103 additions & 9 deletions
@@ -1281,7 +1281,7 @@ public extension RNNCell {
 }
 
 /// A Simple RNN Cell.
-public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
+public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
     public var weight: Tensor<Scalar>
     public var bias: Tensor<Scalar>

@@ -1304,9 +1304,13 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
     /// - Parameters:
     ///   - inputSize: The number of features in 2-D input tensors.
     ///   - hiddenSize: The number of features in 2-D hidden states.
-    public init(inputSize: Int, hiddenSize: Int) {
+    ///   - seed: The random seed for initialization. The default value is random.
+    public init(inputSize: Int, hiddenSize: Int,
+                seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
+                                        Int64.random(in: Int64.min..<Int64.max))) {
         let concatenatedInputSize = inputSize + hiddenSize
-        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
+        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize],
+                             seed: seed)
         self.bias = Tensor(zeros: [hiddenSize])
     }

@@ -1326,7 +1330,7 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
 }
 
 /// An LSTM Cell.
-public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
+public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric {
     public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
     public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

@@ -1348,17 +1352,19 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
     /// - Parameters:
     ///   - inputSize: The number of features in 2-D input tensors.
     ///   - hiddenSize: The number of features in 2-D hidden states.
-    public init(inputSize: Int, hiddenSize: Int) {
+    public init(inputSize: Int, hiddenSize: Int,
+                seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
+                                        Int64.random(in: Int64.min..<Int64.max))) {
         let concatenatedInputSize = inputSize + hiddenSize
         let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
         let gateBiasShape = TensorShape([hiddenSize])
-        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
+        self.inputWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.inputBias = Tensor(zeros: gateBiasShape)
-        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
+        self.updateWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.updateBias = Tensor(zeros: gateBiasShape)
-        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
+        self.forgetWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.forgetBias = Tensor(ones: gateBiasShape)
-        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
+        self.outputWeight = Tensor(glorotUniform: gateWeightShape, seed: seed)
         self.outputBias = Tensor(zeros: gateBiasShape)
     }

@@ -1397,3 +1403,91 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
         return Output(output: newState, state: newState)
     }
 }
+
+public struct RNN<Cell: RNNCell>: Layer {
+    public typealias Input = [Cell.TimeStepInput]
+    public typealias Output = [Cell.TimeStepOutput]
+
+    public var cell: Cell
+
+    public init(_ cell: @autoclosure () -> Cell) {
+        self.cell = cell()
+    }
+
+    @differentiable(wrt: (self, input), vjp: _vjpCall(_:initialState:))
+    public func call(_ input: [Cell.TimeStepInput],
+                     initialState: Cell.State) -> [Cell.TimeStepOutput] {
+        var currentHiddenState = initialState
+        var timeStepOutputs: [Cell.TimeStepOutput] = []
+        for timestep in input {
+            let output = cell(input: timestep, state: currentHiddenState)
+            currentHiddenState = output.state
+            timeStepOutputs.append(output.output)
+        }
+        return timeStepOutputs
+    }
+
+    @usableFromInline
+    internal func _vjpCall(
+        _ inputs: [Cell.TimeStepInput], initialState: Cell.State
+    ) -> ([Cell.TimeStepOutput],
+          (Array<Cell.TimeStepOutput>.CotangentVector)
+              -> (CotangentVector, Array<Cell.TimeStepInput>.CotangentVector)) {
+        let timeStepCount = inputs.count
+        var currentHiddenState = cell.zeroState
+        var timeStepOutputs: [Cell.TimeStepOutput] = []
+        timeStepOutputs.reserveCapacity(timeStepCount)
+        var backpropagators: [Cell.Backpropagator] = []
+        backpropagators.reserveCapacity(timeStepCount)
+        for timestep in inputs {
+            let (output, backpropagator) =
+                cell.appliedForBackpropagation(to: .init(input: timestep,
+                                                         state: currentHiddenState))
+            currentHiddenState = output.state
+            timeStepOutputs.append(output.output)
+            backpropagators.append(backpropagator)
+        }
+        return (timeStepOutputs, { 𝛁outputs in
+            precondition(𝛁outputs.base.count == timeStepCount,
+                         "The number of output gradients must equal the number of time steps")
+            var 𝛁cell = Cell.CotangentVector.zero
+            var 𝛁state = Cell.State.CotangentVector.zero
+            var reversed𝛁inputs: [Cell.TimeStepInput.CotangentVector] = []
+            reversed𝛁inputs.reserveCapacity(timeStepCount)
+            for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() {
+                let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state))
+                𝛁cell = new𝛁cell
+                𝛁state = 𝛁input.state
+                reversed𝛁inputs.append(𝛁input.input)
+            }
+            return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())))
+        })
+    }
+
+    @differentiable(wrt: (self, inputs))
+    public func call(_ inputs: [Cell.TimeStepInput]) -> [Cell.TimeStepOutput] {
+        return self(inputs, initialState: cell.zeroState.withoutDerivative())
+    }
+
+    /* TODO: Uncomment once control flow and differentiation through force unwrapping is supported.
+    @differentiable(wrt: (self, inputs))
+    public func lastOutput(from inputs: [Cell.TimeStepInput],
+                           initialState: Cell.State) -> Cell.TimeStepOutput {
+        precondition(!inputs.isEmpty, "inputs cannot be empty")
+        return self(inputs, initialState: initialState).last!
+    }
+
+    @differentiable(wrt: (self, inputs))
+    public func lastOutput(from inputs: [Cell.TimeStepInput]) -> Cell.TimeStepOutput {
+        precondition(!inputs.isEmpty, "inputs cannot be empty")
+        return self(inputs, initialState: cell.zeroState).last!
+    }
+    */
+}
+
+extension RNN: Equatable where Cell: Equatable {}
+extension RNN: AdditiveArithmetic where Cell: AdditiveArithmetic {}
+extension RNN: VectorNumeric where Cell: VectorNumeric {}
+
+public typealias SimpleRNN<Scalar: TensorFlowFloatingPoint> = RNN<SimpleRNNCell<Scalar>>
+public typealias LSTM<Scalar: TensorFlowFloatingPoint> = RNN<LSTMCell<Scalar>>
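
For illustration, here is a minimal usage sketch of the seeded initializer and the new RNN wrapper. This driver code is not part of the commit: it assumes the DeepLearning module API at this revision, and the sizes, seed values, and variable names are arbitrary.

    import TensorFlow
    import DeepLearning

    // Two cells built with the same seed draw identical Glorot-uniform
    // weights, so initialization is reproducible across runs.
    let cellA = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (42, 7))
    let cellB = SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4, seed: (42, 7))
    print(cellA.weight)  // Prints the same values as cellB.weight.
    print(cellB.weight)

    // The RNN wrapper threads the hidden state through a sequence of time
    // steps, starting from the cell's zero state, and returns one output
    // per time step.
    let rnn = SimpleRNN<Float>(SimpleRNNCell(inputSize: 4, hiddenSize: 4,
                                             seed: (42, 7)))
    let step = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
    let sequence = Array(repeating: step, count: 4)
    let outputs = rnn(sequence)  // Four [1 x 4] outputs, one per time step.

Note that SimpleRNN<Float> is just a typealias for RNN<SimpleRNNCell<Float>>, so the RNN(SimpleRNNCell<Float>(...)) spelling used in the tests below is equivalent.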

Tests/DeepLearningTests/LayerTests.swift

Lines changed: 27 additions & 1 deletion
@@ -95,6 +95,31 @@ final class LayerTests: XCTestCase {
         XCTAssertEqual(output, expected)
     }
 
+    func testRNN() {
+        let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
+        let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
+        let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
+                                           seed: (0xFeedBeef, 0xDeadBeef)))
+        let (outputs, pullback) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
+            return rnn(inputs)
+        }
+        XCTAssertEqual(outputs, [[[-0.0026294366, -0.0058668107,   0.04495003,  0.20311214]],
+                                 [[  0.06788494,   0.050665878,   0.02415526,  0.09249911]],
+                                 [[  0.06621192,   0.009049267,  0.065047316,  0.11534518]],
+                                 [[  0.05612204, 0.00022032857,   0.05407162,  0.09784105]]])
+        let (𝛁rnn, 𝛁inputs) = pullback(.init(inputs))
+        XCTAssertEqual(𝛁rnn.cell.weight,
+                       [[          0.0,          0.0,        0.0,         0.0],
+                        [-0.0051278225, 0.0013102926, 0.00740262, 0.018119661],
+                        [ -0.010255645, 0.0026205853, 0.01480524, 0.036239322],
+                        [ -0.015383467,  0.003930878, 0.02220786, 0.054358985],
+                        [          0.0,          0.0,        0.0,         0.0],
+                        [          0.0,          0.0,        0.0,         0.0],
+                        [          0.0,          0.0,        0.0,         0.0],
+                        [          0.0,          0.0,        0.0,         0.0]])
+        XCTAssertEqual(𝛁rnn.cell.bias, [-0.051278222, 0.013102926, 0.0740262, 0.18119662])
+    }
+
     static var allTests = [
         ("testConv1D", testConv1D),
         ("testMaxPool1D", testMaxPool1D),
@@ -104,6 +129,7 @@ final class LayerTests: XCTestCase {
         ("testGlobalAvgPool3D", testGlobalAvgPool3D),
         ("testReshape", testReshape),
         ("testFlatten", testFlatten),
-        ("testSimpleRNNCell", testSimpleRNNCell)
+        ("testSimpleRNNCell", testSimpleRNNCell),
+        ("testRNN", testRNN)
     ]
 }
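
To make the pullback mechanics concrete, here is a hedged sketch of a hypothetical continuation of testRNN above (not code from this commit). The pullback returned by valueWithPullback consumes one cotangent per time step; _vjpCall then walks the saved per-step backpropagators in reverse, threading the state cotangent from the last time step back to the first, which is exactly backpropagation through time.

    // Hypothetical continuation of testRNN: seeding the pullback with
    // ones-like cotangents yields d(sum of all outputs)/d(parameters, inputs).
    let ones = outputs.map { Tensor<Float>(ones: $0.shape) }
    let (𝛁rnnSum, 𝛁inputsSum) = pullback(.init(ones))
    print(𝛁rnnSum.cell.bias)      // Bias gradient accumulated over all time steps.
    print(𝛁inputsSum.base.count)  // One input gradient per time step (4).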
