-
Notifications
You must be signed in to change notification settings - Fork 607
Commit ea27724
committed
Update on "[ET-VK][Ops] aten.embedding"
## The Operator
`nn.Module` invocations on the embedding returned by [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) get compiled to `aten.embedding.default` in the Edge Dialect, which carries the following signature.
```
- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
```
## Implementation
This is a C-packing-only implementation.
Interestingly, the 1D-`indices` case is equivalent to the `dim=0` case of the preceding `aten.index_select`: #3744
```
- func: index_select(Tensor self, int dim, Tensor index) -> Tensor
```
I naïvely thought the rest of the operator would be similarly easy but it wasn't. The 2D and 3D-`indices` cases are more involved to the extent that we require a standalone `cpp`/`glsl` file.
## Codegen
We add support for making 2D and 3D index tensors. This requires new generation functions as well as renaming of the `case_name` string to recursively handle list `pylist`s.
```
// 1D
Test(weight=[10, 9], indices=[0, 2]),
// 2D
Test(weight=[10, 9], indices=[[0, 2], [1, 4], [7, 7]]),
// 3D
Test(weight=[10, 9], indices=[[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]]),
```
Differential Revision: [D57880520](https://our.internmc.facebook.com/intern/diff/D57880520/)
[ghstack-poisoned]File tree
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changedFilter options
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changed
0 commit comments