Conversation
```swift
public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
    @differentiable(wrt: (self, input))
    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
        let batchSize = input.shape[0]
```
PR #22848 changed the behavior of differentiation through non-differentiable types, so it may break this function (requiring a `.withoutDerivative()` call somewhere). Holding off until the next toolchain release (Monday) will be a good idea.

Also, a more efficient implementation would be to use `.shapeTensor` whenever you need to get the shape from a tensor, and not use any values of `TensorShape` type in the function body. On line 510, `newShape` can be defined using `Tensor`'s concatenation operator, `Tensor.concatenated(with:)` or `++`. (BTW, concatenation ops do not have a derivative defined yet, would you like to add one?)
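For concreteness, a rough sketch of what that suggestion might look like follows (a hypothetical illustration, not code from this PR; it assumes `shapeTensor` range slicing, the `++` concatenation operator, and `-1` shape inference in `reshaped(toShape:)` work as expected):

```swift
// Hypothetical sketch: build newShape on the TF side by concatenating shape
// tensors, instead of constructing a host-side TensorShape value.
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
    let batchDim = input.shapeTensor[0..<1]          // batch axis size, as a rank-1 Tensor<Int32>
    let newShape = batchDim ++ Tensor<Int32>([-1])   // -1 lets reshape infer the flattened size
    return input.reshaped(toShape: newShape)
}
```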
New toolchain released!
My recommendation would be:

```swift
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
    let batchSize = input.shape[0]
    let remaining = input.shape[1..<input.rank].contiguousSize
    return input.reshaped(to: [batchSize, remaining])
}
```
But I also see @rxwei's point about preferring purely TF-side shape manipulation. If we don't switch to that completely, though, we should replace `reshaped(toShape: Tensor<Int32>(foo))` with `reshaped(to: foo)` (assuming that actually works now).
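As a usage illustration of the recommended implementation (hypothetical shapes and `Context` construction; not taken from the PR), it flattens everything after the batch dimension:

```swift
// A [32, 28, 28, 1] batch of images should flatten to [32, 784].
let images = Tensor<Float>(zeros: [32, 28, 28, 1])
let context = Context(learningPhase: .inference)  // construction assumed for this API era
let flattened = Flatten<Float>().applied(to: images, in: context)
print(flattened.shape)  // expected: [32, 784]
```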
The following is the host-free version.

```swift
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
    let batchSize = input.shapeTensor[0]
    let remaining = input.shapeTensor[1..<input.rank].product()
    return input.reshaped(toShape: Tensor([batchSize, remaining]))
}
```
The host-free version is currently an order of magnitude slower on CPU than the host-side version (about 30 µs for mine and 300 µs for yours).
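A timing harness along these lines (hypothetical; the iteration count, shapes, and `Context` construction are assumptions, not from the PR) could reproduce that kind of measurement:

```swift
import Dispatch
import TensorFlow

// Hypothetical micro-benchmark: average wall-clock time per applied(to:in:) call.
let input = Tensor<Float>(randomNormal: [32, 28, 28, 1])
let layer = Flatten<Float>()
let context = Context(learningPhase: .inference)

let iterations = 1000
let start = DispatchTime.now().uptimeNanoseconds
for _ in 0..<iterations {
    _ = layer.applied(to: input, in: context)
}
let elapsed = DispatchTime.now().uptimeNanoseconds - start
print("mean: \(Double(elapsed) / Double(iterations) / 1_000) µs per call")
```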
D'oh, that's terrible. Ok.
Given that triplet loss is not defined in Keras, we'd like to be careful about adding APIs without a precedent in other high-level libraries. Could you please make this PR not depend on #31? Thanks!
Hi @tanmayb123, I applied some simplifications and this PR is ready! We definitely want to make it into the v0.2 release, so I'm going to merge it now to get it in the toolchain. We appreciate your contributions!