Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add Recurrent Layers #71

Merged
merged 26 commits into from
Apr 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions Sources/DeepLearning/Layer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1283,3 +1283,123 @@ public extension RNNCell {
return applied(to: RNNCellInput(input: input, state: state))
}
}

/// A simple RNN cell.
public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    /// The weight matrix, of shape `[inputSize + hiddenSize, hiddenSize]`.
    public var weight: Tensor<Scalar>
    /// The bias vector, of shape `[hiddenSize]`.
    public var bias: Tensor<Scalar>

    /// The shape of a hidden state for a batch size of 1.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, weight.shape[1]])
    }

    /// A zero-valued initial state.
    @differentiable
    public var zeroState: Tensor<Scalar> {
        return Tensor(zeros: stateShape)
    }

    public typealias State = Tensor<Scalar>
    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates a `SimpleRNNCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        // Input and previous state are concatenated before the matrix multiply,
        // so a single weight matrix covers both.
        let concatenatedInputSize = inputSize + hiddenSize
        self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize])
        self.bias = Tensor(zeros: [hiddenSize])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input and previous hidden state for this time step.
    /// - Returns: The new hidden state, as both the time-step output and the carried state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // Concatenate along the feature axis (axis 1): a `[batch, inputSize]` input and a
        // `[batch, hiddenSize]` state form a `[batch, inputSize + hiddenSize]` tensor that
        // matches `weight`'s first dimension. The default axis (0) would stack along the
        // batch dimension and fail for differing feature counts — as `LSTMCell` does, the
        // axis must be explicit.
        let concatenatedInput = input.input.concatenated(with: input.state, alongAxis: 1)
        // NOTE(review): no nonlinearity is applied here; a classic simple RNN applies
        // `tanh` to the pre-activation. The current test pins the linear behavior —
        // confirm whether the activation was intentionally omitted.
        let newState = matmul(concatenatedInput, weight) + bias
        return Output(output: newState, state: newState)
    }
}

/// An LSTM cell.
public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    /// Per-gate weight matrices, each of shape `[inputSize + hiddenSize, hiddenSize]`.
    public var inputWeight, updateWeight, forgetWeight, outputWeight: Tensor<Scalar>
    /// Per-gate bias vectors, each of shape `[hiddenSize]`.
    public var inputBias, updateBias, forgetBias, outputBias: Tensor<Scalar>

    /// The shape of each state component for a batch size of 1.
    @noDerivative public var stateShape: TensorShape {
        return TensorShape([1, inputWeight.shape[1]])
    }

    /// A zero-valued initial state (both cell and hidden components).
    @differentiable
    public var zeroState: State {
        return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
    }

    public typealias TimeStepInput = Tensor<Scalar>
    public typealias TimeStepOutput = State
    public typealias Input = RNNCellInput<TimeStepInput, State>
    public typealias Output = RNNCellOutput<TimeStepOutput, State>

    /// Creates an `LSTMCell` with the specified input size and hidden state size.
    ///
    /// - Parameters:
    ///   - inputSize: The number of features in 2-D input tensors.
    ///   - hiddenSize: The number of features in 2-D hidden states.
    public init(inputSize: Int, hiddenSize: Int) {
        let weightShape = TensorShape([inputSize + hiddenSize, hiddenSize])
        let biasShape = TensorShape([hiddenSize])
        inputWeight = Tensor(glorotUniform: weightShape)
        updateWeight = Tensor(glorotUniform: weightShape)
        forgetWeight = Tensor(glorotUniform: weightShape)
        outputWeight = Tensor(glorotUniform: weightShape)
        inputBias = Tensor(zeros: biasShape)
        updateBias = Tensor(zeros: biasShape)
        // The forget bias starts at one so the cell initially retains its state.
        forgetBias = Tensor(ones: biasShape)
        outputBias = Tensor(zeros: biasShape)
    }

    /// The LSTM state: a cell component and a hidden component.
    public struct State: Differentiable {
        public var cell: Tensor<Scalar>
        public var hidden: Tensor<Scalar>

        @differentiable
        public init(cell: Tensor<Scalar>, hidden: Tensor<Scalar>) {
            self.cell = cell
            self.hidden = hidden
        }
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input and previous state for this time step.
    /// - Returns: The new state, as both the time-step output and the carried state.
    @differentiable
    public func applied(to input: Input) -> Output {
        // All four gates read the same concatenation of the time-step input and the
        // previous hidden state along the feature axis.
        let combined = input.input.concatenated(with: input.state.hidden, alongAxis: 1)

        let forget = sigmoid(matmul(combined, forgetWeight) + forgetBias)
        let write = sigmoid(matmul(combined, inputWeight) + inputBias)
        let candidate = tanh(matmul(combined, updateWeight) + updateBias)
        let read = sigmoid(matmul(combined, outputWeight) + outputBias)

        // Forget part of the old cell state, then write in the gated candidate.
        let cell = input.state.cell * forget + write * candidate
        let hidden = tanh(cell) * read

        let state = State(cell: cell, hidden: hidden)
        return Output(output: state, state: state)
    }
}
16 changes: 15 additions & 1 deletion Tests/DeepLearningTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output.shape, expected)
}

func testSimpleRNNCell() {
    // Deterministic weights/bias so the expected activations can be computed by hand:
    // each output feature is (sum of the concatenated input) * column scale + 1.
    var cell = SimpleRNNCell<Float>(inputSize: 2, hiddenSize: 5)
    cell.weight = Tensor<Float>(ones: [7, 5]) * Tensor<Float>([0.3333, 1, 0.3333, 1, 0.3333])
    cell.bias = Tensor<Float>(ones: [5])
    let previousState = Tensor<Float>(ones: [1, 5]) * Tensor<Float>([1, 0.2, 0.5, 2, 0.6])
    let timeStepInput = Tensor<Float>(ones: [1, 2]) * Tensor<Float>([0.3, 0.7])
    let newState = cell.applied(to: timeStepInput, state: previousState).state
    // Concatenated features sum to 5.3, so columns are 5.3 * 0.3333 + 1 and 5.3 * 1 + 1.
    let expected = Tensor<Float>([[2.76649, 6.2999997, 2.76649, 6.2999997, 2.76649]])
    XCTAssertEqual(newState, expected)
}

static var allTests = [
("testConv1D", testConv1D),
("testMaxPool1D", testMaxPool1D),
Expand All @@ -90,6 +103,7 @@ final class LayerTests: XCTestCase {
("testGlobalAvgPool2D", testGlobalAvgPool2D),
("testGlobalAvgPool3D", testGlobalAvgPool3D),
("testReshape", testReshape),
("testFlatten", testFlatten)
("testFlatten", testFlatten),
("testSimpleRNNCell", testSimpleRNNCell)
]
}