This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add Recurrent Layers #71

Merged 26 commits into tensorflow:master on Apr 17, 2019

Conversation

@tanmayb123 (Contributor) commented Apr 2, 2019

#52

@tanmayb123
Contributor Author

Would we prefer to have forgetWeight and forgetBias or forgetW and forgetB?

@dan-zheng
Member

Would we prefer to have forgetWeight and forgetBias or forgetW and forgetB?

The full names Weight and Bias are preferable.

@tanmayb123
Contributor Author

Done.

@dan-zheng (Member) left a comment

Please add a test to Tests/DeepLearningTests/LayerTests.swift!

@eaplatanios
Contributor

It would also be nice if we could have an RNN cell protocol, so that all valid cells have to conform to it and can be used in code that expects RNN cells (e.g., simple RNN, bidirectional RNN, etc.).

@tanmayb123
Contributor Author

@rxwei What do you think such a protocol would look like? i.e. What kind of functionality would it define?

@tanmayb123 changed the title from Add LSTM Cell to Add Recurrent Layers on Apr 2, 2019
@rxwei
Contributor

rxwei commented Apr 2, 2019

Just a heads-up: I probably won't have time to review this or make suggestions until late evening.

@rxwei
Contributor

rxwei commented Apr 6, 2019

Here's a sketch:

public struct RNNInput<StepInput: Differentiable, State: Differentiable>: Differentiable {
    public var stepInput: StepInput
    public var previousState: State
    public init(stepInput: StepInput, previousState: State) {
        self.stepInput = stepInput
        self.previousState = previousState
    }
}

public struct RNNState<CellState: Differentiable, HiddenState: Differentiable>: Differentiable {
    public var cell: CellState
    public var hidden: HiddenState
    public init(cell: CellState, hidden: HiddenState) {
        self.cell = cell
        self.hidden = hidden
    }
}

public protocol RNNCell: Layer
    where Input == RNNInput<StepInput, State>, Output == RNNState<CellState, HiddenState> {
    associatedtype StepInput: Differentiable
    associatedtype CellState: Differentiable
    associatedtype HiddenState: Differentiable
    typealias State = Output

    init(inputSize: Int, hiddenSize: Int)
    var zeroState: State { get }
}

public extension RNNCell {
    @differentiable
    func applied(to stepInput: StepInput, _ previousState: State) -> State {
        return applied(to: Input(stepInput: stepInput, previousState: previousState))
    }
}
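
For illustration, unrolling a cell over a sequence with that convenience method might look like this (cell and inputs here are hypothetical, not part of the sketch):

// Hypothetical usage: step a cell over a sequence of inputs,
// threading the state through each time step.
var state = cell.zeroState
for stepInput in inputs {
    state = cell.applied(to: stepInput, state)
}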

return State(cell: newCellState, hidden: newHiddenState)
}

public func zeroState() -> State {

Contributor

It would be more Swifty to define this as a computed property, since the computational complexity of this is relatively trivial.
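
For example (a sketch; it assumes the cell stores hiddenSize, as in the initializer excerpts elsewhere in this review):

// Zero state as a computed property rather than a method.
public var zeroState: State {
    let shape = TensorShape([1, Int32(hiddenSize)])
    return State(cell: Tensor(zeros: shape), hidden: Tensor(zeros: shape))
}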

Member

I think you might want to have zeroState (or maybe initialState) accept batchSize as the inputs are likely to be batched.
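
A sketch of that suggestion (again assuming hiddenSize is a stored property):

// Zero state parameterized on batch size, so batched inputs line up.
public func zeroState(batchSize: Int) -> State {
    let shape = TensorShape([Int32(batchSize), Int32(hiddenSize)])
    return State(cell: Tensor(zeros: shape), hidden: Tensor(zeros: shape))
}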

let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
let outputGate = sigmoid(matmul(gateInput, outputWeight))

let newCellState = (input.state.cell * forgetGate + inputGate * updateGate)

Contributor

Remove redundant parentheses.
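
That is, the assignment above would become:

let newCellState = input.state.cell * forgetGate + inputGate * updateGate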

self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
self.forgetBias = Tensor(zeros: [Int32(hiddenSize)])
self.outputWeight = Tensor(glorotUniform: gateWeightShape)
self.stateShape = TensorShape([1, concatenatedInputSize])

Contributor

Always prefer literal conversion when a contextual type exists. In this case the contextual type is TensorShape and it conforms to ExpressibleByArrayLiteral.

Suggested change:
- self.stateShape = TensorShape([1, concatenatedInputSize])
+ self.stateShape = [1, concatenatedInputSize]

@tanmayb123
Contributor Author

Love the sketch, Richard - thanks :)
Quick question: only LSTMs have a "cell state". Others, like GRU and SimpleRNN, have only a "hidden state", so their "state" wouldn't be a struct but rather a single Tensor. That's why I was a bit confused about how to create that protocol.

@eaplatanios
Contributor

Sorry I wasn't able to respond earlier; I've been kept busy. I was thinking of something a bit more abstract, where RNNState does not need to consist of two parts but is rather a generic type itself (calling it just State should be fine), since it can vary between different RNN cells.

@rxwei
Contributor

rxwei commented Apr 7, 2019

Yeah, the cell state is definitely weird. I think that part may not need a fixed structure; as @eaplatanios said, it can just be a generic type.

public struct RNNInput<StepInput: Differentiable, State: Differentiable>: Differentiable {
    public var stepInput: StepInput
    public var previousState: State
    public init(stepInput: StepInput, previousState: State) {
        self.stepInput = stepInput
        self.previousState = previousState
    }
}

public protocol RNNCell: Layer where Input == RNNInput<StepInput, State> {
    associatedtype StepInput: Differentiable
    typealias State = Output

    init(inputSize: Int, hiddenSize: Int)
    var zeroState: State { get }
}

public extension RNNCell {
    @differentiable
    func applied(to stepInput: StepInput, _ previousState: State) -> State {
        return applied(to: Input(stepInput: stepInput, previousState: previousState))
    }
}
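
With State left fully generic, a cell whose state is a single tensor can conform directly. Here is a minimal sketch against the Layer API of that era (SimpleRNNCell and all of its members are illustrative, not code from this PR):

public struct SimpleRNNCell: RNNCell {
    public var weight: Tensor<Float>
    public var bias: Tensor<Float>
    @noDerivative public let hiddenSize: Int

    public init(inputSize: Int, hiddenSize: Int) {
        self.weight = Tensor(glorotUniform: [Int32(inputSize + hiddenSize), Int32(hiddenSize)])
        self.bias = Tensor(zeros: [Int32(hiddenSize)])
        self.hiddenSize = hiddenSize
    }

    // The state is just the hidden activation, so State == Output == Tensor<Float>.
    public var zeroState: Tensor<Float> {
        return Tensor(zeros: [1, Int32(hiddenSize)])
    }

    @differentiable
    public func applied(to input: RNNInput<Tensor<Float>, Tensor<Float>>) -> Tensor<Float> {
        // Concatenate the step input with the previous hidden state along the feature axis.
        let gateInput = input.stepInput.concatenated(with: input.previousState, alongAxis: 1)
        return tanh(matmul(gateInput, weight) + bias)
    }
}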

@rxwei
Contributor

rxwei commented Apr 7, 2019

I'm going to check the protocol in. If you can make these new layers work with the new protocol, that'd be really great!

@rxwei mentioned this pull request on Apr 7, 2019
@tanmayb123
Contributor Author

I'll do that shortly - thanks Richard!


self.inputWeight = Tensor(glorotUniform: gateWeightShape)
self.updateWeight = Tensor(glorotUniform: gateWeightShape)
self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
self.forgetBias = Tensor(zeros: [Int32(hiddenSize)])

Member

A rule of thumb is to initialize the forget bias to ~1; see http://proceedings.mlr.press/v37/jozefowicz15.pdf

Also, is there a reason for not having a bias for other gates?
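
Following the forget-bias rule of thumb above, the excerpted initialization could become something like:

self.forgetBias = Tensor(ones: [Int32(hiddenSize)])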

@rxwei
Contributor

rxwei commented Apr 16, 2019

Hi @tanmayb123, we are preparing a Swift for TensorFlow v0.3 release this Wednesday. Would you like to update this PR and check it in so that it can be part of the release?

@tanmayb123
Contributor Author

Sure thing Richard, working on that now. Quick concern that I didn't notice before: RNNCellOutput has both an Output and a State, but that's a bit problematic for the cells, since they don't have any output per time step - their output is the state. What should I do in that case? Should I pass the new state to both Output and State?

@rxwei
Contributor

rxwei commented Apr 17, 2019

I didn't see James's comment. Definitely address his concerns first :)

@tanmayb123
Contributor Author

Also, related to this PR, I've opened #91 to explore how to handle sequential inputs and keep track of hidden states over time automatically.

@tanmayb123
Contributor Author

The build is failing, but not on any of the Recurrent Cell code.

@rxwei
Contributor

rxwei commented Apr 17, 2019

I'm going to merge it once tests pass.

@rxwei
Contributor

rxwei commented Apr 17, 2019

Tests are failing because your branch is still old. Could you pull and merge?

@tanmayb123
Contributor Author

Sure thing.

@rxwei merged commit 6dc373a into tensorflow:master on Apr 17, 2019
@rxwei
Contributor

rxwei commented Apr 17, 2019

Merged. Thanks for iterating on this!

@tanmayb123
Contributor Author

Of course :)

dan-zheng added a commit that referenced this pull request Apr 17, 2019
dan-zheng added a commit that referenced this pull request Apr 17, 2019
This reverts commit 6dc373a.
It exposed a differentiation crash (TF-440) and is blocking progress.
dan-zheng added a commit that referenced this pull request Apr 18, 2019
rxwei pushed a commit that referenced this pull request Apr 18, 2019
* Revert "Revert "Add Recurrent Layers (#71)" (#94)"

This reverts commit f75c5e0.

* Remove `@differentiable` from `zeroState`.

* Fix axis of concatenation.