Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 760fafc

Browse files
jon-towrxwei
authored and committed
Add an Embedding layer (#257)
1 parent e6f2e04 commit 760fafc

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/// Wraps the integer indices that are fed into an `Embedding` layer.
///
/// - Note: When an `Embedding` is followed by a `Flatten` and a `Dense` layer, make sure that
///   every input sequence of indices has the same length.
// NOTE: Required so that `Embedding` can conform to the `Layer` protocol.
@frozen
public struct EmbeddingInput: Differentiable {
    /// The index sequences to look up in the embedding table.
    @noDerivative public var indices: Tensor<Int32>

    /// Creates an input wrapper around the given embedding indices.
    ///
    /// - Parameter indices: The embedding indices.
    public init(indices: Tensor<Int32>) {
        self.indices = indices
    }
}
32+
33+
/// An embedding layer.
///
/// `Embedding` is effectively a lookup table that maps indices from a fixed vocabulary to
/// fixed-size (dense) vector representations, e.g.
/// `[[0], [3]] -> [[0.25, 0.1], [0.6, -0.2]]`.
public struct Embedding<Scalar: TensorFlowFloatingPoint>: Layer {
    /// A learnable lookup table that maps vocabulary indices to their dense vector
    /// representations.
    public var embeddings: Tensor<Scalar>

    /// Creates an `Embedding` layer with randomly initialized embeddings of shape
    /// `(vocabularySize, embeddingSize)` so that each vocabulary index is given a vector
    /// representation.
    ///
    /// - Parameters:
    ///   - vocabularySize: The number of distinct indices (words) in the vocabulary. This number
    ///     should be the largest integer index plus one.
    ///   - embeddingSize: The number of entries in a single embedding vector representation.
    public init(vocabularySize: Int, embeddingSize: Int) {
        self.embeddings = Tensor(randomUniform: [vocabularySize, embeddingSize])
    }

    /// Creates an `Embedding` layer from the provided embeddings. Useful for introducing
    /// pretrained embeddings into a model.
    ///
    /// - Parameter embeddings: The pretrained embeddings table.
    public init(embeddings: Tensor<Scalar>) {
        self.embeddings = embeddings
    }

    /// Returns an output by replacing each index in the input with its corresponding dense
    /// vector representation.
    ///
    /// - Parameter input: The indices that will be mapped to their vector representations.
    /// - Returns: The tensor created by replacing input indices with their vector
    ///   representations.
    @differentiable
    public func callAsFunction(_ input: EmbeddingInput) -> Tensor<Scalar> {
        return embeddings.gathering(atIndices: input.indices)
    }
}

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,23 @@ final class LayerTests: XCTestCase {
205205
XCTAssertEqual(output.shape, expected)
206206
}
207207

208+
func testEmbedding() {
    // Randomly initialized embeddings: only the output shape is deterministic.
    var layer = Embedding<Float>(vocabularySize: 3, embeddingSize: 5)
    var indices = Tensor<Int32>(shape: [2, 3], scalars: [0, 1, 2, 1, 2, 2])
    var output = layer.inferring(from: EmbeddingInput(indices: indices))
    XCTAssertEqual(output.shape, TensorShape([2, 3, 5]))

    // Pretrained embeddings: the gathered rows are checked exactly.
    let pretrained = Tensor<Float>(shape: [2, 2], scalars: [0.4, 0.3, 0.2, 0.1])
    layer = Embedding<Float>(embeddings: pretrained)
    indices = Tensor<Int32>(shape: [2, 2], scalars: [0, 1, 1, 1])
    output = layer.inferring(from: EmbeddingInput(indices: indices))
    let expected = Tensor<Float>([[[0.4, 0.3], [0.2, 0.1]], [[0.2, 0.1], [0.2, 0.1]]])
    XCTAssertEqual(output, expected)
}
224+
208225
func testSimpleRNNCell() {
209226
let weight = Tensor<Float>(ones: [7, 5]) * Tensor<Float>([0.3333, 1, 0.3333, 1, 0.3333])
210227
let bias = Tensor<Float>(ones: [5])
@@ -272,6 +289,7 @@ final class LayerTests: XCTestCase {
272289
("testUpSampling3D", testUpSampling3D),
273290
("testReshape", testReshape),
274291
("testFlatten", testFlatten),
292+
("testEmbedding", testEmbedding),
275293
("testSimpleRNNCell", testSimpleRNNCell),
276294
("testRNN", testRNN)
277295
]

0 commit comments

Comments
 (0)