Conversation
Would we prefer to have
The full names
Done.
Please add a test to Tests/DeepLearningTests/LayerTests.swift!
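For illustration, such a test might look like the following (a hypothetical sketch: the `LSTMCell` name, its `init(inputSize:hiddenSize:)`, and the `applied(to:_:)`/`zeroState` API are assumptions based on the diffs and protocol sketches in this thread):

```swift
// Hypothetical test sketch for Tests/DeepLearningTests/LayerTests.swift.
// Assumes an LSTM-style cell with init(inputSize:hiddenSize:), a zeroState,
// and the applied(to:_:) convenience sketched later in this thread.
func testLSTMCell() {
    let cell = LSTMCell<Float>(inputSize: 2, hiddenSize: 4)
    let input = Tensor<Float>(zeros: [1, 2])
    let newState = cell.applied(to: input, cell.zeroState)
    // Exact state shapes depend on the implementation; at minimum the
    // batch dimension should be preserved.
    XCTAssertEqual(newState.hidden.shape[0], 1)
}
```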
It would also be nice if we could have an RNN cell protocol, so that all valid cells have to conform to it and can be used in code that expects RNN cells (e.g., simple RNN, bi-directional RNN, etc.).
@rxwei What do you think such a protocol would look like? I.e., what kind of functionality would it define?
Just a heads-up: I probably won't have time to review this or make suggestions until late evening.
Here's a sketch:

```swift
public struct RNNInput<StepInput: Differentiable, State: RNNState>: Differentiable {
    public var stepInput: StepInput
    public var previousState: State

    public init(stepInput: StepInput, previousState: State) {
        self.stepInput = stepInput
        self.previousState = previousState
    }
}

public struct RNNState<CellState: Differentiable, HiddenState: Differentiable>: Differentiable {
    public var cell: CellState
    public var hidden: HiddenState

    public init(cell: CellState, hidden: HiddenState) {
        self.cell = cell
        self.hidden = hidden
    }
}

public protocol RNNCell: Layer
    where Input == RNNInput<StepInput, State>, Output == RNNState<CellState, HiddenState> {
    associatedtype StepInput: Differentiable
    associatedtype CellState: Differentiable
    associatedtype HiddenState: Differentiable
    typealias State = Output
    init(inputSize: Int, hiddenSize: Int)
    var zeroState: State { get }
}

public extension RNNCell {
    @differentiable
    func applied(to stepInput: StepInput, _ previousState: State) -> State {
        return applied(to: Input(stepInput: stepInput, previousState: previousState))
    }
}
```
Sources/DeepLearning/Layer.swift (outdated):

```swift
        return State(cell: newCellState, hidden: newHiddenState)
    }

    public func zeroState() -> State {
```
It would be more Swifty to define this as a computed property, since the computational complexity of this is relatively trivial.
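Something like this, for concreteness (a sketch, assuming `State` keeps its `cell`/`hidden` initializer and the `stateShape` stored property from the diff below):

```swift
// Sketch: zeroState as a computed property rather than a method.
public var zeroState: State {
    return State(cell: Tensor(zeros: stateShape), hidden: Tensor(zeros: stateShape))
}
```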
I think you might want to have `zeroState` (or maybe `initialState`) accept `batchSize`, as the inputs are likely to be batched.
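For example (a hypothetical sketch; whether the cell stores `hiddenSize` as a property is an assumption):

```swift
// Sketch: a zero state parameterized by batch size, so the initial state
// matches batched step inputs.
public func zeroState(batchSize: Int) -> State {
    return State(
        cell: Tensor(zeros: [Int32(batchSize), Int32(hiddenSize)]),
        hidden: Tensor(zeros: [Int32(batchSize), Int32(hiddenSize)]))
}
```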
Sources/DeepLearning/Layer.swift (outdated):

```swift
        let forgetGate = sigmoid(matmul(gateInput, forgetWeight) + forgetBias)
        let outputGate = sigmoid(matmul(gateInput, outputWeight))

        let newCellState = (input.state.cell * forgetGate + inputGate * updateGate)
```
Remove redundant parentheses.
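That is, the line would become:

```swift
let newCellState = input.state.cell * forgetGate + inputGate * updateGate
```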
Sources/DeepLearning/Layer.swift (outdated):

```swift
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        self.forgetBias = Tensor(zeros: [Int32(hiddenSize)])
        self.outputWeight = Tensor(glorotUniform: gateWeightShape)
        self.stateShape = TensorShape([1, concatenatedInputSize])
```
Always prefer literal conversion when a contextual type exists. In this case the contextual type is `TensorShape`, and it conforms to `ExpressibleByArrayLiteral`:

```diff
- self.stateShape = TensorShape([1, concatenatedInputSize])
+ self.stateShape = [1, concatenatedInputSize]
```
Love the sketch, Richard - thanks :)
Sorry I wasn't able to respond earlier because I've been kept busy. I was thinking of something a bit more abstract, where the state is just a generic type rather than a fixed structure.
Yeah, the cell state is definitely weird. I think that part may not need a fixed structure; like @eaplatanios said, it can just be a generic type.

```swift
public struct RNNInput<StepInput: Differentiable, State: Differentiable>: Differentiable {
    public var stepInput: StepInput
    public var previousState: State

    public init(stepInput: StepInput, previousState: State) {
        self.stepInput = stepInput
        self.previousState = previousState
    }
}

public protocol RNNCell: Layer where Input == RNNInput<StepInput, State> {
    associatedtype StepInput: Differentiable
    typealias State = Output
    init(inputSize: Int, hiddenSize: Int)
    var zeroState: State { get }
}

public extension RNNCell {
    @differentiable
    func applied(to stepInput: StepInput, _ previousState: State) -> State {
        return applied(to: Input(stepInput: stepInput, previousState: previousState))
    }
}
```
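For illustration, a hypothetical helper that is generic over this protocol could then unroll a whole sequence (the `unroll` function is not part of the sketch above, just an assumed usage example):

```swift
// Threads the state through a sequence of step inputs, starting from the
// zero state, and returns the final state.
func unroll<Cell: RNNCell>(_ cell: Cell, over stepInputs: [Cell.StepInput]) -> Cell.State {
    var state = cell.zeroState
    for stepInput in stepInputs {
        state = cell.applied(to: stepInput, state)
    }
    return state
}
```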
I'm going to check the protocol in. If you can make these new layers work with the new protocol, that'd be really great!
I'll do that shortly - thanks, Richard!
Sources/DeepLearning/Layer.swift (outdated):

```swift
        self.inputWeight = Tensor(glorotUniform: gateWeightShape)
        self.updateWeight = Tensor(glorotUniform: gateWeightShape)
        self.forgetWeight = Tensor(glorotUniform: gateWeightShape)
        self.forgetBias = Tensor(zeros: [Int32(hiddenSize)])
```
A rule of thumb is to initialize the forget bias to ~1; see http://proceedings.mlr.press/v37/jozefowicz15.pdf.
Also, is there a reason for not having a bias for the other gates?
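For example (a sketch; the bias property names other than `forgetBias` are hypothetical):

```swift
// Give every gate a bias, and start the forget bias at 1 per Jozefowicz
// et al. (2015) so the cell initially remembers by default.
self.inputBias = Tensor(zeros: [Int32(hiddenSize)])
self.updateBias = Tensor(zeros: [Int32(hiddenSize)])
self.forgetBias = Tensor(ones: [Int32(hiddenSize)])
self.outputBias = Tensor(zeros: [Int32(hiddenSize)])
```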
Hi @tanmayb123, we are preparing a Swift for TensorFlow v0.3 release this Wednesday. Would you like to update this PR and check it in so that it can be part of the release?
Sure thing, Richard; working on that now. Quick concern that I didn't notice before:
I didn't see James's comment. Definitely address his concerns first :)
Also, related to this PR, I've opened #91 to see how to handle sequential inputs and keep track of hidden states over time automatically.
The build is failing, but not on any of the Recurrent Cell code.
I'm going to merge it once tests pass.
Tests are failing because your branch is still old. Could you pull and merge?
Sure thing.
Merged. Thanks for iterating on this!
Of course :)
This reverts commit 6dc373a.
#52