Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add an Embedding layer #257

Merged
merged 18 commits into from
Jun 18, 2019
Merged

Add an Embedding layer #257

merged 18 commits into from
Jun 18, 2019

Conversation

jon-tow
Copy link
Contributor

@jon-tow jon-tow commented Jun 18, 2019

Add support for embeddings.

#54

It might be a good idea to find a type alias for this layer to avoid having to supply a useless generic type argument, e.g.Embedding<Float>, since a TensorFlowFloatingPoint is not the true type (Int) of input into the layer. Maybe Swift will soon implement default generic parameters?

@rxwei rxwei requested review from dan-zheng, rxwei and jekbradbury June 18, 2019 19:17
jon-tow and others added 6 commits June 18, 2019 15:34
Add space between colon and protocol name.

Co-Authored-By: Richard Wei <[email protected]>
Correct misspelling.

Co-Authored-By: Richard Wei <[email protected]>
Remove unneeded line break in documentation.

Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
@rxwei
Copy link
Contributor

rxwei commented Jun 18, 2019

Maybe Swift will soon implement default generic parameters?

Yep, the Swift for TensorFlow team plans to implement default generic arguments at some point if it's not yet done by the Swift community.

Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really great. Thanks for putting this together!

@jon-tow
Copy link
Contributor Author

jon-tow commented Jun 18, 2019

@rxwei Glad to help out! Sorry for all the tiny mistakes. I'll be much more careful on any future PRs I submit.

@frozen
public struct EmbeddingInput: Differentiable {
/// Sequences of indices that will be passed into the layer.
@noDerivative public var indices: Tensor<Int32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's absolutely possible to have an embedding layer with more than 2^32 embeddings (and therefore one that requires Int64 indices), but it's fine to not worry about that for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gathering(atIndices:alongAxis:) is not yet generic over BinaryInteger. To fix this, we need to start from there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxwei rxwei merged commit 760fafc into tensorflow:master Jun 18, 2019
@jon-tow jon-tow deleted the layer/embedding branch June 18, 2019 22:03
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants