This repository was archived by the owner on Jul 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 137
Add an Embedding layer #257
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
1797941
Add an Embedding layer
jon-tow 2ad17bc
Merge branch 'master' of https://github.com/tensorflow/swift-apis int…
jon-tow b86539b
Update Sources/TensorFlow/Layers/Embedding.swift
jon-tow d7ba1bb
Update Sources/TensorFlow/Layers/Embedding.swift
jon-tow 8882081
Update Sources/TensorFlow/Layers/Embedding.swift
jon-tow b27b6cd
Rewrite struct documentation for conciseness.
jon-tow 634f200
Refactor doc example.
jon-tow 096c35f
Add proper documentation for implementation note.
jon-tow 0e1ac9a
Remove unnecessary backquotes in documentation.
jon-tow 3d83e8a
Add proper documentation for non-mutating function.
jon-tow 66b8a72
Add space between array entries.
jon-tow d3e4512
Add space between array entries.
jon-tow 7e01430
Remove argument label from Embedding initializer.
jon-tow b4ae209
Add proper documentation for return values.
jon-tow e06c0b5
Remove documentation for non-existent identifier.
jon-tow 2e47d49
Merge branch 'layer/embedding' of https://github.com/jon-tow/swift-ap…
jon-tow dcaf45c
Rename argument label to match implementation.
jon-tow fe5f88f
Add a top level EmbeddingInput along with tests for its correctness.
jon-tow File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
/// An input structure containing the embedding indices. | ||
/// | ||
/// - Note: Often times, `Embedding` is followed by a `Flatten` and a `Dense` layer. When this | ||
/// is the case, ensure that all input sequences of indices have the same dimension. | ||
// NOTE: This structure is needed to conform `Embedding` to the Layer protocol. | ||
@frozen | ||
public struct EmbeddingInput: Differentiable { | ||
/// Sequences of indices that will be passed into the layer. | ||
@noDerivative public var indices: Tensor<Int32> | ||
|
||
/// Creates an `EmbeddingInput`. | ||
/// | ||
/// - Parameter indices: The embedding indices. | ||
public init(indices: Tensor<Int32>) { | ||
self.indices = indices | ||
} | ||
} | ||
|
||
/// An embedding layer. | ||
/// | ||
/// `Embedding` is effectively a lookup table that maps indices from a fixed vocabulary to fixed-size | ||
/// (dense) vector representations, e.g. `[[0], [3]] -> [[0.25, 0.1], [0.6, -0.2]]`. | ||
public struct Embedding<Scalar: TensorFlowFloatingPoint>: Layer { | ||
/// A learnable lookup table that maps vocabulary indices to their dense vector representations. | ||
public var embeddings: Tensor<Scalar> | ||
|
||
/// Creates an `Embedding` layer with randomly initialized embeddings of shape | ||
/// `(vocabularySize, embeddingSize)` so that each vocabulary index is given a vector | ||
/// representation. | ||
/// | ||
/// - Parameters: | ||
/// - vocabularySize: The number of distinct indices (words) in the vocabulary. This number | ||
/// should be the largest integer index plus one. | ||
/// - embeddingSize: The number of entries in a single embedding vector representation. | ||
public init(vocabularySize: Int, embeddingSize: Int) { | ||
self.embeddings = Tensor(randomUniform: [vocabularySize, embeddingSize]) | ||
} | ||
|
||
/// Creates an `Embedding` layer from the provided embeddings. Useful for introducing | ||
/// pretrained embeddings into a model. | ||
/// | ||
/// - Parameter embeddings: The pretrained embeddings table. | ||
public init(embeddings: Tensor<Scalar>) { | ||
self.embeddings = embeddings | ||
} | ||
|
||
/// Returns an output by replacing each index in the input with corresponding dense vector representation. | ||
/// | ||
/// - Parameter | ||
/// - input: The indices that will be mapped to their vector representations. | ||
/// - Returns: The tensor created by replacing input indices with their vector representations. | ||
@differentiable | ||
public func callAsFunction(_ input: EmbeddingInput) -> Tensor<Scalar> { | ||
return embeddings.gathering(atIndices: input.indices) | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's absolutely possible to have an embedding layer with more than 2^32 embeddings (and therefore one that requires
Int64
indices), but it's fine to not worry about that for now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gathering(atIndices:alongAxis:)
is not yet generic overBinaryInteger
. To fix this, we need to start from there.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#259