Add flatten & reshape layers #32

Merged: 8 commits into tensorflow:master on Feb 28, 2019

Conversation

@tanmayb123 (Contributor) commented Feb 25, 2019

No description provided.

public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
    @differentiable(wrt: (self, input))
    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
        let batchSize = input.shape[0]
@rxwei (Contributor) commented Feb 25, 2019

PR #22848 changed the behavior of differentiation through non-differentiable types, so it may break this function (requiring a .withoutDerivative() call somewhere). Holding off until the next toolchain release (Monday) would be a good idea.
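For concreteness, a hedged sketch of where such a .withoutDerivative() call might land (the exact placement is an assumption, not a confirmed fix):

// Assumption: reading the shape off a non-differentiated copy of `input`
// keeps the new rules for non-differentiable types from tripping on this line.
let batchSize = input.withoutDerivative().shape[0]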

Also, a more efficient implementation would be to use .shapeTensor whenever you need to get the shape from a tensor, and not use any values of TensorShape type in the function body. On line 510, newShape can be defined using Tensor's concatenation operator, Tensor.concatenated(with:) or ++. (BTW, concatenation ops do not have a derivative defined yet; would you like to add one?)
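A rough sketch of that TF-side construction (variable names are illustrative; shapeTensor, product(), ++ / Tensor.concatenated(with:), and reshaped(toShape:) are from this thread, while rankLifted() is an assumption about the API of that era):

// Keep every shape value as a Tensor<Int32> so nothing round-trips through the host.
let batchDim = input.shapeTensor[0..<1]            // first dimension, as a rank-1 tensor
let restDims = input.shapeTensor[1..<input.rank]   // the remaining dimensions
let flatRest = restDims.product().rankLifted()     // their product, lifted back to rank 1
let newShape = batchDim ++ flatRest                // Tensor.concatenated(with:), spelled ++
return input.reshaped(toShape: newShape)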

Contributor:

New toolchain released!

Contributor:

My recommendation would be:

public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
    let batchSize = input.shape[0]
    let remaining = input.shape[1..<input.rank].contiguousSize
    return input.reshaped(to: [batchSize, remaining])
}

But I also see @rxwei's point about preferring purely TF-side shape manipulation.

If we don't switch to that completely, though, we should replace reshaped(toShape: Tensor<Int32>(foo)) with reshaped(to: foo) (assuming that actually works now).
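If it helps, a hypothetical usage sketch of the recommendation above (the Tensor(ones:) and Context(learningPhase:) spellings are assumptions about the toolchain of that era):

let flatten = Flatten<Float>()
let batch = Tensor<Float>(ones: [32, 28, 28, 1])   // e.g. a batch of MNIST-sized images
// Assumed Context spelling; the exact initializer may differ.
let flat = flatten.applied(to: batch, in: Context(learningPhase: .inference))
// flat.shape == [32, 784]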

@rxwei (Contributor) commented Feb 28, 2019

The following is the host-free version.

public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
    let batchSize = input.shapeTensor[0]
    let remaining = input.shapeTensor[1..<input.rank].product()
    return input.reshaped(toShape: Tensor([batchSize, remaining]))
}

Contributor:

The host-free version is currently an order of magnitude slower on CPU than the hosty version (about 30 µs for mine and 300 µs for yours).
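For reference, a minimal timing harness along these lines (hypothetical; not the harness behind the 30 µs / 300 µs numbers):

import Dispatch

// Average wall-clock time per call over many iterations.
func measure(_ label: String, iterations: Int = 1_000, _ body: () -> Void) {
    let start = DispatchTime.now().uptimeNanoseconds
    for _ in 0..<iterations { body() }
    let nanos = DispatchTime.now().uptimeNanoseconds - start
    print("\(label): \(nanos / UInt64(iterations)) ns per call")
}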

Contributor:

D'oh, that's terrible. OK.

@rxwei (Contributor) commented Feb 25, 2019

Given that triplet loss is not defined in Keras, we'd like to be careful about adding APIs without a precedent in other high-level libraries. Could you please make this PR not depend on #31? Thanks!

@rxwei (Contributor) commented Feb 28, 2019

Hi @tanmayb123, I applied some simplifications and this PR is ready! We definitely want to make it into the v0.2 release, so I'm going to merge it now to get it in the toolchain. We appreciate your contributions!

@rxwei rxwei merged commit 6f5b962 into tensorflow:master Feb 28, 2019