Skip to content

Commit 9258781

Browse files
committed
Update base for Update on "[ET-VK][Ops] aten.embedding"
## The Operator `nn.Module` invocations on the embedding returned by [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) get compiled to `aten.embedding.default` in the Edge Dialect, which carries the following signature. ``` - func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor ``` ## Implementation This is a C-packing-only implementation. Interestingly, the 1D-`indices` case is equivalent to the `dim=0` case of the preceding `aten.index_select`: #3744 ``` - func: index_select(Tensor self, int dim, Tensor index) -> Tensor ``` I naïvely thought the rest of the operator would be similarly easy but it wasn't. The 2D and 3D-`indices` cases are more involved to the extent that we require a standalone `cpp`/`glsl` file. ## Codegen We add support for making 2D and 3D index tensors. This requires new generation functions as well as renaming of the `case_name` string to recursively handle list `pylist`s. ``` // 1D Test(weight=[10, 9], indices=[0, 2]), // 2D Test(weight=[10, 9], indices=[[0, 2], [1, 4], [7, 7]]), // 3D Test(weight=[10, 9], indices=[[[3, 1, 4], [1, 5, 9]], [[2, 6, 5], [3, 5, 8]]]), ``` Differential Revision: [D57880520](https://our.internmc.facebook.com/intern/diff/D57880520/) [ghstack-poisoned]
2 parents a208948 + 55d11e1 commit 9258781

31 files changed

+1051
-159
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __contains__(self, op):
112112
]
113113

114114
CREATION_OPS = [
115+
exir_ops.edge.aten.arange.start_step,
115116
exir_ops.edge.aten.clone.default,
116117
exir_ops.edge.aten.full.default,
117118
]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_ubo(1, "ivec4", "sizes")}
21+
${layout_declare_ubo(2, "float", "start")}
22+
${layout_declare_ubo(3, "float", "step")}
23+
24+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
25+
26+
layout(constant_id = 3) const int packed_dim = C_DIM;
27+
28+
// Fills the output texture for arange: each invocation writes one texel whose
// first component is `start + pos.x * step`; the remaining three components
// are written as zero.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the tensor extents have nothing to write.
  if (pos_out_of_bounds(pos, sizes, packed_dim)) {
    return;
  }

  // Only the x coordinate of the texel position contributes to the value.
  VEC4_T outtex = VEC4_T(start + pos.x * step, 0, 0, 0);

  imageStore(t_out, pos, outtex);
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
arange:
8+
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: int
11+
STORAGE: texture3d
12+
PACKING: C_packed
13+
generate_variant_forall:
14+
DTYPE:
15+
- VALUE: half
16+
- VALUE: float
17+
- VALUE: int
18+
shader_variants:
19+
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,14 @@
1919

2020
layout(std430) buffer;
2121

22-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
23-
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
24-
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other;
25-
26-
layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
27-
ivec4 out_sizes;
28-
};
29-
30-
layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
31-
ivec4 in_sizes;
32-
};
33-
34-
layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
35-
ivec4 other_sizes;
36-
};
37-
38-
layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams {
39-
ivec2 broadcast_params;
40-
};
41-
42-
layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
43-
float alpha;
44-
};
22+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
23+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
24+
${layout_declare_tensor(2, "r", "t_other", DTYPE, STORAGE)}
25+
${layout_declare_ubo(3, "ivec4", "out_sizes")}
26+
${layout_declare_ubo(4, "ivec4", "in_sizes")}
27+
${layout_declare_ubo(5, "ivec4", "other_sizes")}
28+
${layout_declare_ubo(6, "ivec2", "broadcast_params")}
29+
${layout_declare_ubo(7, "float", "alpha")}
4530

4631
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4732

@@ -57,13 +42,13 @@ void main() {
5742

5843
ivec4 in_idx = broadcast_indices(idx, in_sizes);
5944
VEC4_T in_texel = VEC4_T(texelFetch(
60-
image_in,
45+
t_in,
6146
to_texture_pos(in_idx, in_sizes, packed_dim),
6247
0));
6348

6449
ivec4 other_idx = broadcast_indices(idx, other_sizes);
6550
VEC4_T other_texel = VEC4_T(texelFetch(
66-
image_other,
51+
t_other,
6752
to_texture_pos(other_idx, other_sizes, packed_dim),
6853
0));
6954

@@ -75,5 +60,5 @@ void main() {
7560
other_texel = other_texel.xxxx;
7661
}
7762

78-
imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
63+
imageStore(t_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
7964
}

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ binary_op:
1010
NDIM: 3
1111
DTYPE: float
1212
PACKING: C_packed
13+
STORAGE: texture3d
1314
generate_variant_forall:
1415
DTYPE:
1516
- VALUE: half

backends/vulkan/runtime/graph/ops/glsl/full.glsl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15-
#include "broadcasting_utils.h"
1615
#include "indexing_utils.h"
1716

1817
layout(std430) buffer;

backends/vulkan/runtime/graph/ops/glsl/upsample.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,15 @@
1313

1414
#define PRECISION ${PRECISION}
1515

16-
#define VEC4_T ${texel_type(DTYPE)}
16+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
1717

1818
layout(std430) buffer;
1919

20-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21-
22-
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
23-
24-
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
25-
ivec3 out_limits;
26-
};
27-
28-
layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
29-
ivec4 sizes;
30-
};
20+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
21+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
22+
${layout_declare_ubo(2, "ivec3", "out_limits")}
23+
${layout_declare_ubo(3, "ivec2", "input_size")}
24+
${layout_declare_ubo(4, "vec2", "rev_scales")}
3125

3226
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3327

@@ -38,6 +32,8 @@ void main() {
3832
return;
3933
}
4034

41-
VEC4_T in_texel = texelFetch(image_in, pos, 0);
42-
imageStore(image_out, pos, in_texel);
35+
const ivec2 ipos = clamp(ivec2(pos.xy * rev_scales), ivec2(0), input_size);
36+
37+
VEC4_T in_texel = texelFetch(t_in, ivec3(ipos, pos.z), 0);
38+
imageStore(t_out, pos, in_texel);
4339
}

backends/vulkan/runtime/graph/ops/glsl/upsample.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
upsample:
7+
upsample_nearest2d:
88
parameter_names_with_default_values:
99
NDIM: 3
1010
DTYPE: float
1111
PACKING: C_packed
12+
STORAGE: texture3d
1213
generate_variant_forall:
1314
DTYPE:
1415
- VALUE: half
1516
- VALUE: float
1617
shader_variants:
17-
- NAME: upsample
18+
- NAME: upsample_nearest2d
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/api/Utils.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
14+
15+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
16+
17+
namespace vkcompute {
18+
19+
/**
 * Resize callback for the arange node.
 *
 * Recomputes the single output dimension as
 * div_up(end - start, step) and applies it to the output tensor.
 *
 * @param graph       Graph that owns the values referenced below.
 * @param args        args[0].refs[0] is the output tensor.
 * @param extra_args  {start, end, step} in that order. start and step may be
 *                    None, in which case they fall back to 0 and 1.
 */
void resize_arange_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);

  // The scalars are extracted as int64_t and the result is stored in a
  // std::vector<int64_t>; keep the arithmetic in int64_t as well so large
  // values are not silently truncated through an intermediate `int`.
  int64_t start_val = 0;
  int64_t step_val = 1;
  if (!graph->val_is_none(extra_args[0])) {
    start_val = graph->extract_scalar<int64_t>(extra_args[0]);
  }
  const int64_t end_val = graph->extract_scalar<int64_t>(extra_args[1]);
  if (!graph->val_is_none(extra_args[2])) {
    step_val = graph->extract_scalar<int64_t>(extra_args[2]);
  }

  std::vector<int64_t> out_sizes = {
      api::utils::div_up(end_val - start_val, step_val)};

  out->virtual_resize(out_sizes);
}
40+
41+
/**
 * Validates the scalar arguments of arange.
 *
 * Each of start, end and step may be None; any argument that is provided
 * must be an int, otherwise VK_THROW is raised with a message naming the
 * offending argument.
 *
 * @param graph  Graph that owns the values.
 * @param start  Reference to the start scalar (may be None).
 * @param end    Reference to the end scalar (may be None as far as this
 *               check is concerned).
 * @param step   Reference to the step scalar (may be None).
 */
void check_arange_input(
    ComputeGraph& graph,
    const ValueRef start,
    const ValueRef end,
    const ValueRef step) {
  // Each argument is checked against itself; previously all three branches
  // inspected `end`, so start/step were never actually validated.
  if (!graph.val_is_none(start) && !graph.val_is_int(start)) {
    VK_THROW("arange: start must be int!");
  }
  if (!graph.val_is_none(end) && !graph.val_is_int(end)) {
    VK_THROW("arange: end must be int!");
  }
  if (!graph.val_is_none(step) && !graph.val_is_int(step)) {
    VK_THROW("arange: step must be int!");
  }
}
56+
57+
/**
 * Appends an ExecuteNode to the graph that runs the `arange` shader.
 *
 * start and step are optional (None falls back to 0 and 1) and may be given
 * as either int or float scalars; both are handed to the shader as floats.
 * end must be provided, but its value is only consumed by the resize
 * callback, not by this function.
 *
 * @param graph  Graph the node is appended to.
 * @param start  Reference to the start scalar, or None.
 * @param end    Reference to the end scalar; VK_THROW if None.
 * @param step   Reference to the step scalar, or None.
 * @param out    Reference to the output tensor.
 */
void add_arange_node(
    ComputeGraph& graph,
    const ValueRef start,
    const ValueRef end,
    const ValueRef step,
    const ValueRef out) {
  if (graph.val_is_none(end)) {
    VK_THROW("arange: end must be specified!");
  }

  // Reads an optional scalar as float: None yields `fallback`, ints are
  // converted explicitly, anything else is extracted as float directly.
  const auto as_float_or = [&graph](
                               const ValueRef ref,
                               const float fallback) -> float {
    if (graph.val_is_none(ref)) {
      return fallback;
    }
    if (graph.val_is_int(ref)) {
      return static_cast<float>(graph.extract_scalar<int64_t>(ref));
    }
    return graph.extract_scalar<float>(ref);
  };

  float start_val = as_float_or(start, 0.0f);
  float step_val = as_float_or(step, 1.0f);

  vTensorPtr t_out = graph.get_tensor(out);

  // One invocation per texel of the output image.
  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  // Shader name is "arange" plus a suffix chosen from the output dtype.
  std::string kernel_name("arange");
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, api::MemoryAccessType::WRITE}},
      // Shader params buffers
      {t_out->sizes_ubo(),
       graph.create_params_buffer(start_val),
       graph.create_params_buffer(step_val)},
      // Specialization Constants
      {},
      // Resizing Logic
      resize_arange_node,
      {start, end, step}));
}
112+
113+
/**
 * Operator entry point for arange: unpacks the positional argument list and
 * forwards to add_arange_node.
 *
 * Only args[0] (start), args[1] (end), args[2] (step) and args[7] (output)
 * are read here; args[3..6] are ignored by this function.
 */
void arange(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef start = args[0];
  const ValueRef end = args[1];
  const ValueRef step = args[2];
  const ValueRef out = args[7];

  add_arange_node(graph, start, end, step, out);
}
116+
117+
REGISTER_OPERATORS {
118+
VK_REGISTER_OP(aten.arange.start_step, arange);
119+
}
120+
121+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@
1717

1818
namespace vkcompute {
1919

20+
inline int64_t normalize_idx(
21+
const int64_t index,
22+
const int64_t max,
23+
const int64_t default_value) {
24+
// INT64_MAX is passed when value is unspecified
25+
if (index == INT64_MAX) {
26+
return default_value;
27+
}
28+
if (index == default_value) {
29+
return index;
30+
}
31+
return normalize(index, max);
32+
}
33+
2034
void add_slice_tensor_out_node(
2135
ComputeGraph& graph,
2236
ValueRef in,
@@ -57,8 +71,8 @@ void add_slice_tensor_out_node(
5771
int64_t start = opt_start.value_or(0);
5872
int64_t end = opt_end.value_or(in_sizes[dim]);
5973

60-
VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
61-
VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
74+
start = normalize_idx(start, in_sizes[dim], 0);
75+
end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);
6276

6377
if (dim_index == kChannel4D) {
6478
// slice by channel

0 commit comments

Comments
 (0)