Skip to content

Commit c072a18

Browse files
committed
[ET-VK][Ops] aten.embedding
## The Operator

`nn.Module` invocations on the embedding returned by [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) get compiled to `aten.embedding.default` in the Edge Dialect, which carries the following signature.

```
- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
```

## Implementation

This is a C-packing-only implementation. Interestingly, the 1D-`indices` case is equivalent to the `dim=0` case of the preceding `aten.index_select` (#3744):

```
- func: index_select(Tensor self, int dim, Tensor index) -> Tensor
```

I naïvely thought the rest of the operator would be similarly easy, but it wasn't. The 2D- and 3D-`indices` cases are more involved, to the extent that we require a standalone `cpp`/`glsl` file.

## Codegen

We add support for making 2D and 3D index tensors. This requires new generation functions, as well as renaming the accumulator string in `init_list_str` (so that the function can call itself) and making `case_name` generation recurse into nested `pylist`s.

```
// 1D
Test(weight=[10, 9], indices=[0, 2]),
// 2D
Test(weight=[10, 9], indices=[[0, 2], [1, 4], [7, 7]]),
// 3D
Test(weight=[10, 9], indices=[[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]]),
```

Differential Revision: [D57880520](https://our.internmc.facebook.com/intern/diff/D57880520/)

ghstack-source-id: 228038402

Pull Request resolved: #3762
1 parent 5e3aa82 commit c072a18

File tree

7 files changed

+255
-14
lines changed

7 files changed

+255
-14
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __contains__(self, op):
9999
]
100100

101101
INDEXING_OPS = [
102+
exir_ops.edge.aten.embedding.default,
102103
exir_ops.edge.aten.index_select.default,
103104
exir_ops.edge.aten.select_copy.int,
104105
exir_ops.edge.aten.slice_copy.Tensor,
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_tensor(1, "r", "t_in", "int", STORAGE)}
21+
${layout_declare_tensor(2, "r", "t_weight", DTYPE, STORAGE)}
22+
${layout_declare_ubo(3, "ivec4", "sizes")}
23+
24+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
25+
26+
layout(constant_id = 3) const int packed_dim = C_DIM;
27+
28+
/*
 * Embedding lookup. Each invocation writes one texel of t_out. For each of the
 * four components of that texel, an integer index is read from t_in and used
 * as the y coordinate of the t_weight texel whose .x component is copied out.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations that fall outside the output extents have nothing to write.
  if (pos_out_of_bounds(pos, sizes, packed_dim)) {
    return;
  }

  const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim);

  // These two values depend only on idx.w, so compute them once up front
  // instead of once per component.
  const int in_z = idx.w / 4;
  const int in_component = idx.w % 4;

  VEC4_T texel;

  // Consider optimizing via W-packing format for t_in and t_weight.
  for (int comp = 0; comp < 4; ++comp) {
    // Fetch the embedding index for this component from the input tensor.
    const ivec3 index_pos = ivec3(pos.y, idx.z * 4 + comp, in_z);
    const int weight_row = texelFetch(t_in, index_pos, 0)[in_component];

    // Fetch the embedding value from the weight tensor.
    texel[comp] = texelFetch(t_weight, ivec3(pos.x, weight_row, 0), 0).x;
  }

  imageStore(t_out, pos, texel);
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
embedding:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
STORAGE: texture3d
6+
generate_variant_forall:
7+
DTYPE:
8+
- VALUE: half
9+
- VALUE: float
10+
- VALUE: int
11+
shader_variants:
12+
- NAME: embedding
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
void check_embedding_args(
21+
const vTensor& weight,
22+
const vTensor& in,
23+
const vTensor& out) {
24+
VK_CHECK_COND(check_memory_layout_is(weight, api::kChannelsPacked));
25+
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
26+
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
27+
}
28+
29+
void add_embedding_node(
30+
ComputeGraph& graph,
31+
ValueRef weight,
32+
ValueRef in,
33+
ValueRef out) {
34+
vTensorPtr t_weight = graph.get_tensor(weight);
35+
vTensorPtr t_in = graph.get_tensor(in);
36+
vTensorPtr t_out = graph.get_tensor(out);
37+
38+
check_embedding_args(*t_weight, *t_in, *t_out);
39+
40+
std::string kernel_name = "embedding";
41+
kernel_name.reserve(kShaderNameReserve);
42+
add_dtype_suffix(kernel_name, *t_out);
43+
44+
api::utils::uvec3 global_size = t_out->image_extents();
45+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
46+
47+
graph.execute_nodes().emplace_back(new ExecuteNode(
48+
graph,
49+
VK_KERNEL_FROM_STR(kernel_name),
50+
global_size,
51+
local_size,
52+
{{out, api::MemoryAccessType::WRITE},
53+
{{in, weight}, api::MemoryAccessType::READ}},
54+
{t_out->sizes_ubo()}));
55+
}
56+
57+
void embedding(ComputeGraph& graph, const std::vector<ValueRef>& args) {
58+
ValueRef weight = prepack_if_tensor_ref(graph, args[0]);
59+
ValueRef in = prepack_if_tensor_ref(graph, args[1]);
60+
ValueRef out = args[5];
61+
62+
add_embedding_node(graph, weight, in, out);
63+
}
64+
65+
REGISTER_OPERATORS {
66+
VK_REGISTER_OP(aten.embedding.default, embedding);
67+
}
68+
69+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,25 @@ def get_index_select_inputs():
423423
return test_suite
424424

425425

426+
def get_embedding_inputs():
    """Build the test suite for ``aten.embedding.default``.

    Every case uses a ``[10, 9]`` weight and a 1D, 2D or 3D list of indices.
    The remaining operator arguments (``-1``, ``"false"``, ``"false"``) are
    appended to each case unchanged.
    """
    Test = namedtuple("VkEmbeddingTest", ["weight", "indices"])
    Test.__new__.__defaults__ = (None, None)

    index_inputs = [
        # 1D indices
        [0, 2],
        [2, 3, 4, 5, 7],
        # 2D indices
        [[0, 2], [1, 4], [7, 7]],
        [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]],
        # 3D indices
        [[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]],
    ]
    # A fresh weight list is created per case so that no two cases share it.
    cases = [Test(weight=[10, 9], indices=idx) for idx in index_inputs]

    trailing_args = (-1, "false", "false")
    suite = VkTestSuite([(*case, *trailing_args) for case in cases])

    suite.dtypes = ["at::kFloat"]
    suite.layouts = ["api::kChannelsPacked"]
    return suite
443+
444+
426445
def get_unsqueeze_inputs():
427446
test_suite = VkTestSuite(
428447
[
@@ -817,6 +836,7 @@ def get_gelu_inputs():
817836
"aten.slice_copy.Tensor": get_slice_inputs(),
818837
"aten.slice.Tensor": get_slice_inputs(),
819838
"aten.index_select.default": get_index_select_inputs(),
839+
"aten.embedding.default": get_embedding_inputs(),
820840
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
821841
"aten.clone.default": get_clone_inputs(),
822842
"aten.repeat.default": get_repeat_inputs(),

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,14 @@ def init_list_str(pylist: Any) -> str:
7878
if not isinstance(pylist, (list, tuple)):
7979
pylist = [pylist]
8080

81-
init_list_str = "{"
81+
list_str = "{"
8282
for s in pylist:
83-
init_list_str += f"{s}, "
84-
init_list_str = init_list_str[:-2] + "}"
85-
return init_list_str
83+
if isinstance(s, (list, tuple)):
84+
list_str += f"{init_list_str(s)}, "
85+
else:
86+
list_str += f"{s}, "
87+
list_str = list_str[:-2] + "}"
88+
return list_str
8689

8790

8891
def get_or_return_default(arg: Argument, inputs: List[Any], i: int):
@@ -105,8 +108,17 @@ def __init__(self, f: NativeFunction, test_suite: TestSuite):
105108
self.f, method=False, fallback_binding=self.f.manual_cpp_binding
106109
).most_faithful_signature()
107110

108-
def gen_case_name_tuple(self, t: Tuple) -> str:
109-
return "x".join([str(e) for e in t])
111+
def gen_case_name_tuple(self, t) -> str:
    """Join the elements of ``t`` with ``"x"``, flattening nested lists/tuples.

    Scalars are rendered with ``str``; any element that is itself a list or
    tuple is rendered by a recursive call, so ``[[0, 2], [1, 4]]`` becomes
    ``"0x2x1x4"``.
    """
    parts = []
    for elem in t:
        if isinstance(elem, (list, tuple)):
            parts.append(self.gen_case_name_tuple(elem))
        else:
            parts.append(str(elem))
    return "x".join(parts)
110122

111123
def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
112124
name_str = self.op_name
@@ -119,7 +131,7 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
119131
elif isinstance(arg_sizes_or_val, list):
120132
lst = []
121133
for size in arg_sizes_or_val:
122-
if isinstance(size, tuple):
134+
if isinstance(size, (list, tuple)):
123135
lst.append(self.gen_case_name_tuple(size))
124136
else:
125137
lst.append(str(size))
@@ -154,7 +166,7 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
154166
ret_str = f"{cpp_type} {arg.name} = "
155167

156168
if cpp_type == AT_TENSOR:
157-
if arg.name == "index":
169+
if arg.name == "index" or arg.name == "indices":
158170
ret_str += f"make_index_tensor({init_list_str(data)});"
159171
else:
160172
ret_str += (
@@ -283,19 +295,52 @@ def generate_suite_cpp(self) -> str:
283295
values[i] = (float) i;
284296
}}
285297
286-
// from_blob doesn't take ownership of data. Hence must create a copy as
287-
// "values" will go out of scope.
298+
// Clone as original data will be deallocated upon return.
288299
return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
289300
}}
290301
291302
292303
// Builds a 1D index tensor of dtype at::kInt from a flat list of indices.
at::Tensor make_index_tensor(std::vector<int64_t> indices) {{
    at::ScalarType dtype = at::kInt;
    std::vector<int64_t> sizes = {{static_cast<int64_t>(indices.size())}};

    // The buffer stores int64_t values, so it must be wrapped as kLong and then
    // converted. Wrapping it directly as kInt would reinterpret the raw bytes
    // and yield values other than the requested indices.
    // Clone as original data will be deallocated upon return.
    return at::from_blob(indices.data(), sizes, at::kLong).toType(dtype).detach().clone();
}}

// Builds a 2D index tensor of dtype at::kInt from a list of rows.
// All rows are expected to have the same length as the first row.
at::Tensor make_index_tensor(std::vector<std::vector<int64_t>> indices) {{
    at::ScalarType dtype = at::kInt;

    // Only look at indices[0] when it exists, so an empty list is not indexed.
    const int64_t dim0 = static_cast<int64_t>(indices.size());
    const int64_t dim1 =
        indices.empty() ? 0 : static_cast<int64_t>(indices[0].size());
    std::vector<int64_t> sizes = {{dim0, dim1}};

    // Flatten indices as from_blob reads garbage otherwise.
    std::vector<int64_t> acc;
    acc.reserve(static_cast<size_t>(dim0 * dim1));
    for (const auto& vec : indices) {{
        acc.insert(acc.end(), vec.begin(), vec.end());
    }}

    // Wrap as kLong to match the int64_t buffer, then convert to the index dtype.
    // Clone as original data will be deallocated upon return.
    return at::from_blob(acc.data(), sizes, at::kLong).toType(dtype).detach().clone();
}}

// Builds a 3D index tensor of dtype at::kInt from a doubly nested list.
// The nested lists are expected to be rectangular, sized like indices[0][0].
at::Tensor make_index_tensor(std::vector<std::vector<std::vector<int64_t>>> indices) {{
    at::ScalarType dtype = at::kInt;

    // Only look at indices[0] and indices[0][0] when they exist.
    const int64_t dim0 = static_cast<int64_t>(indices.size());
    const int64_t dim1 =
        indices.empty() ? 0 : static_cast<int64_t>(indices[0].size());
    const int64_t dim2 = (indices.empty() || indices[0].empty())
        ? 0
        : static_cast<int64_t>(indices[0][0].size());
    std::vector<int64_t> sizes = {{dim0, dim1, dim2}};

    // Flatten indices as from_blob reads garbage otherwise.
    std::vector<int64_t> acc;
    acc.reserve(static_cast<size_t>(dim0 * dim1 * dim2));
    for (const auto& v : indices) {{
        for (const auto& vv : v) {{
            acc.insert(acc.end(), vv.begin(), vv.end());
        }}
    }}

    // Wrap as kLong to match the int64_t buffer, then convert to the index dtype.
    // Clone as original data will be deallocated upon return.
    return at::from_blob(acc.data(), sizes, at::kLong).toType(dtype).detach().clone();
}}
300345
301346
{test_suites_cpp}

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,3 +1291,48 @@ def forward(self, x):
12911291
sample_inputs,
12921292
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
12931293
)
1294+
1295+
def test_vulkan_backend_embedding_1d(self):
    """Lower ``torch.nn.Embedding`` with a 1D index tensor and check its output."""

    class EmbeddingModule(torch.nn.Module):
        def __init__(self, embedding):
            super().__init__()
            self.embedding = embedding

        def forward(self, x):
            return self.embedding(x)

    # The indices below go up to 4, so the table needs at least 5 rows;
    # with num_embeddings=4 the lookup of index 4 is out of range.
    self.lower_module_and_test_output(
        EmbeddingModule(torch.nn.Embedding(num_embeddings=5, embedding_dim=5)),
        (torch.tensor([0, 1, 0, 4, 2, 0], dtype=torch.int32),),
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )
1309+
1310+
def test_vulkan_backend_embedding_2d(self):
    """Lower ``torch.nn.Embedding`` with a 2D index tensor and check its output."""

    class EmbeddingModule(torch.nn.Module):
        def __init__(self, embedding):
            super().__init__()
            self.embedding = embedding

        def forward(self, x):
            return self.embedding(x)

    # The indices below go up to 4, so the table needs at least 5 rows;
    # with num_embeddings=4 the lookup of index 4 is out of range.
    self.lower_module_and_test_output(
        EmbeddingModule(torch.nn.Embedding(num_embeddings=5, embedding_dim=5)),
        (torch.tensor([[0, 1, 0], [4, 2, 0]], dtype=torch.int32),),
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )
1324+
1325+
def test_vulkan_backend_embedding_3d(self):
    """Lower ``torch.nn.Embedding`` with a 3D index tensor and check its output."""

    class EmbeddingModule(torch.nn.Module):
        def __init__(self, embedding):
            super().__init__()
            self.embedding = embedding

        def forward(self, x):
            return self.embedding(x)

    # The indices below go up to 4, so the table needs at least 5 rows;
    # with num_embeddings=4 the lookup of index 4 is out of range.
    self.lower_module_and_test_output(
        EmbeddingModule(torch.nn.Embedding(num_embeddings=5, embedding_dim=5)),
        (torch.tensor([[[0, 1], [0, 1]], [[4, 2], [3, 3]]], dtype=torch.int32),),
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )

0 commit comments

Comments
 (0)