Skip to content

Commit 7b4b2ae

Browse files
committed
Update on "[ET-VK][Ops] aten.embedding"
## The Operator `nn.Module` invocations on the embedding returned by [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) get compiled to `aten.embedding.default` in the Edge Dialect, which carries the following signature. ``` - func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor ``` ## Implementation This is a C-packing-only implementation. Interestingly, the 1D-`indices` case is equivalent to the `dim=0` case of the preceding `aten.index_select`: #3744 ``` - func: index_select(Tensor self, int dim, Tensor index) -> Tensor ``` I naïvely thought the rest of the operator would be similarly easy but it wasn't. The 2D and 3D-`indices` cases are more involved to the extent that we require a standalone `cpp`/`glsl` file. ## Codegen We add support for making 2D and 3D index tensors. This requires new generation functions as well as renaming of the `case_name` string to recursively handle list `pylist`s. ``` // 1D Test(weight=[10, 9], indices=[0, 2]), // 2D Test(weight=[10, 9], indices=[[0, 2], [1, 4], [7, 7]]), // 3D Test(weight=[10, 9], indices=[[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]]), ``` Differential Revision: [D57880520](https://our.internmc.facebook.com/intern/diff/D57880520/) [ghstack-poisoned]
2 parents ea27724 + 8d8ce21 commit 7b4b2ae

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import re
8-
from typing import Any, List, Tuple
8+
from typing import Any, List
99

1010
from torchgen.api import cpp
1111
from torchgen.api.types import CppSignatureGroup

0 commit comments

Comments
 (0)