Skip to content

Commit c665c17

Browse files
jorgep31415facebook-github-bot
authored andcommitted
aten.index_select (#3744)
Summary: Pull Request resolved: #3744 ## The Operator `nn.Module` invocations of [`torch.index_select`](https://pytorch.org/docs/stable/generated/torch.index_select.html) get compiled to `aten.index_select.default` in the Edge Dialect, which carries the following signature. ``` - func: index_select(Tensor self, int dim, Tensor index) -> Tensor ``` ## Implementation This is a C-packing-only implementation. It is very similar to `aten.slice`: #3171 ``` - func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) ``` It features a similar split between a shader for N,H,W and a shader for C, because copying from the C-dimension is more difficult due to C-packing. Both `index_select` and `slice` copy specific indices along one dimension. The difference is in the way these indices are specified. - `slice` uses `start=1`/`end=5`/`step=2` as three scalars for indices `1,3`. - `index_select` lists the exact indices inside a tensor e.g. `index=torch.tensor([1,3])`. Hence, `slice` uses an `offset=1` and `step=2` to compute the input position. In `index_select`, we read the index tensor to compute the input position. Reviewed By: copyrightly Differential Revision: D57745489 fbshipit-source-id: 4ef7f1a13d4bf74af20fe61149dbd5d461aaab0c
1 parent 24b37f2 commit c665c17

File tree

9 files changed

+336
-5
lines changed

9 files changed

+336
-5
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __contains__(self, op):
9999
]
100100

101101
INDEXING_OPS = [
102+
exir_ops.edge.aten.index_select.default,
102103
exir_ops.edge.aten.select_copy.int,
103104
exir_ops.edge.aten.slice_copy.Tensor,
104105
]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
21+
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
22+
${layout_declare_ubo(3, "ivec4", "sizes")}
23+
${layout_declare_ubo(4, "int", "gpu_dim", "int", "stride")}
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
layout(constant_id = 3) const int packed_dim = C_DIM;
28+
29+
// index_select along an axis that maps directly onto a texture axis.
// Each invocation copies one whole texel of t_in into t_out. The coordinate
// along gpu_dim is split into (entry of the index tensor, offset inside that
// entry's run of `stride` consecutive texels); the entry is then replaced by
// the value read from t_idx.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (pos_out_of_bounds(pos, sizes, packed_dim)) {
    return;
  }

  const int coord = pos[gpu_dim];
  const int idx_entry = coord / stride;
  const int offset_in_run = coord % stride;

  // The index tensor is addressed along x only; the selected index is the
  // first component of the fetched texel.
  const int src_entry = texelFetch(t_idx, ivec3(idx_entry, 0, 0), 0).x;

  ivec3 src_pos = pos;
  src_pos[gpu_dim] = src_entry * stride + offset_in_run;

  imageStore(t_out, pos, texelFetch(t_in, src_pos, 0));
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
index_select:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
STORAGE: texture3d
6+
generate_variant_forall:
7+
DTYPE:
8+
- VALUE: half
9+
- VALUE: float
10+
- VALUE: int
11+
shader_variants:
12+
- NAME: index_select
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
21+
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
22+
${layout_declare_ubo(3, "ivec4", "out_sizes")}
23+
${layout_declare_ubo(4, "ivec4", "in_sizes")}
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
layout(constant_id = 3) const int packed_dim = C_DIM;
28+
29+
// index_select along the channel axis.
// The four components of an output texel are gathered one by one: each
// component's tensor index is recovered from its NCHW buffer index, its
// channel (z) is replaced by the value read from t_idx, and the matching
// element is fetched from t_in.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (pos_out_of_bounds(pos, out_sizes, packed_dim)) {
    return;
  }

  const ivec4 tensor_idx = to_tensor_idx(pos, out_sizes, packed_dim);
  const ivec4 nchw_ixs =
      get_texel_nchw_buffer_ixs(tensor_idx, out_sizes, packed_dim);

  VEC4_T result;
  for (int lane = 0; lane < 4; ++lane) {
    // Start from the output element's index and swap in the source channel.
    ivec4 elem_idx = from_nchw_buffer_i(nchw_ixs[lane], out_sizes);
    elem_idx.z = texelFetch(t_idx, ivec3(elem_idx.z, 0, 0), 0).x;

    // xyz addresses the source texel, w selects the component inside it.
    const ivec4 src = to_texture_elem_pos(elem_idx, in_sizes, packed_dim);
    const VEC4_T src_texel = texelFetch(t_in, src.xyz, 0);

    result[lane] = src_texel[src.w];
  }
  imageStore(t_out, pos, result);
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
index_select_channel:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
STORAGE: texture3d
6+
generate_variant_forall:
7+
DTYPE:
8+
- VALUE: half
9+
- VALUE: float
10+
- VALUE: int
11+
shader_variants:
12+
- NAME: index_select_channel
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
// Validates the tensors handed to the index_select nodes: the input, the
// index tensor and the output must each pass
// check_memory_layout_is(..., api::kChannelsPacked); VK_CHECK_COND reports
// the first one that does not.
void check_index_select_args(
21+
const vTensor& in,
22+
const vTensor& idx,
23+
const vTensor& out) {
24+
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
25+
VK_CHECK_COND(check_memory_layout_is(idx, api::kChannelsPacked));
26+
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
27+
}
28+
29+
void add_index_select_channel_node(
30+
ComputeGraph& graph,
31+
ValueRef in,
32+
ValueRef idx,
33+
ValueRef out) {
34+
vTensorPtr t_in = graph.get_tensor(in);
35+
vTensorPtr t_idx = graph.get_tensor(idx);
36+
vTensorPtr t_out = graph.get_tensor(out);
37+
38+
check_index_select_args(*t_in, *t_idx, *t_out);
39+
40+
std::string kernel_name = "index_select_channel";
41+
kernel_name.reserve(kShaderNameReserve);
42+
add_dtype_suffix(kernel_name, *t_out);
43+
44+
api::utils::uvec3 global_size = t_out->image_extents();
45+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
46+
47+
graph.execute_nodes().emplace_back(new ExecuteNode(
48+
graph,
49+
VK_KERNEL_FROM_STR(kernel_name),
50+
global_size,
51+
local_size,
52+
{{out, api::MemoryAccessType::WRITE},
53+
{{in, idx}, api::MemoryAccessType::READ}},
54+
{t_out->sizes_ubo(), t_in->sizes_ubo()}));
55+
}
56+
57+
// Uniform parameters for the "index_select" shader. add_index_select_node
// uploads an instance via graph.create_params_buffer(); the field order
// mirrors the shader-side declaration ("int" gpu_dim, "int" stride).
// gpu_dim: texture axis to index along (0, 1 or 2, see
//          create_index_select_params).
// stride:  number of consecutive texels along gpu_dim that share one entry of
//          the index tensor.
struct IndexSelectParams final {
58+
int32_t gpu_dim;
59+
int32_t stride;
60+
};
61+
62+
// Translates a tensor dimension index into the shader parameters used by the
// "index_select" shader.
//
// dim_idx: one of kWidth4D, kHeight4D or kBatch4D; anything else triggers
//          VK_THROW("Unexpected dim_idx!").
// in:      input tensor; only consulted for the batch case, where its channel
//          count determines the stride.
// Returns {gpu_dim, stride}: {0, 1} for width, {1, 1} for height and
// {2, div_up_4(channels)} for batch.
IndexSelectParams create_index_select_params(
    const int64_t dim_idx,
    const vTensor& in) {
  if (dim_idx == kWidth4D) {
    return {0, 1};
  }
  if (dim_idx == kHeight4D) {
    return {1, 1};
  }
  if (dim_idx != kBatch4D) {
    VK_THROW("Unexpected dim_idx!");
  }

  // Batch case: the stride along axis 2 is div_up_4 of the channel count.
  const int64_t channels = dim_at(in.sizes(), kChannel4D);
  const int64_t texels_per_batch = api::utils::div_up_4(channels);
  return {2, static_cast<int32_t>(texels_per_batch)};
}
77+
78+
// Appends an ExecuteNode running the "index_select" shader (with the dtype
// suffix of the output tensor), used for every dimension that
// create_index_select_params accepts.
//
// graph:   compute graph that owns the tensors and receives the node.
// in:      reference to the input tensor.
// dim_idx: dimension index to select along, forwarded to
//          create_index_select_params.
// idx:     reference to the index tensor.
// out:     reference to the output tensor.
void add_index_select_node(
    ComputeGraph& graph,
    ValueRef in,
    const int64_t dim_idx,
    ValueRef idx,
    ValueRef out) {
  vTensorPtr in_tensor = graph.get_tensor(in);
  vTensorPtr idx_tensor = graph.get_tensor(idx);
  vTensorPtr out_tensor = graph.get_tensor(out);

  check_index_select_args(*in_tensor, *idx_tensor, *out_tensor);

  IndexSelectParams shader_params =
      create_index_select_params(dim_idx, *in_tensor);

  std::string shader_name;
  shader_name.reserve(kShaderNameReserve);
  shader_name += "index_select";
  add_dtype_suffix(shader_name, *out_tensor);

  // One invocation per output texel.
  api::utils::uvec3 global_wg = out_tensor->image_extents();
  api::utils::uvec3 local_wg = adaptive_work_group_size(global_wg);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      global_wg,
      local_wg,
      // Output is written; input and index tensors are read.
      {{out, api::MemoryAccessType::WRITE},
       {{in, idx}, api::MemoryAccessType::READ}},
      // Uniforms: output sizes followed by {gpu_dim, stride}.
      {out_tensor->sizes_ubo(), graph.create_params_buffer(shader_params)}));
}
108+
109+
// Reads the scalar `dim` argument and converts it into a dimension index for
// the input tensor: the value is first passed through normalize() together
// with the tensor's rank, then through normalize_to_dim_index().
//
// graph:   compute graph holding both values.
// in:      reference to the input tensor.
// dim_ref: reference to the int64 scalar holding the requested dimension.
int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) {
  vTensorPtr in_tensor = graph.get_tensor(in);
  const int64_t requested_dim = graph.extract_scalar<int64_t>(dim_ref);
  const int64_t normalized_dim = normalize(requested_dim, in_tensor->dim());
  return normalize_to_dim_index(*in_tensor, normalized_dim);
}
115+
116+
// Operator entry point for aten.index_select.default.
// args is laid out as {self, dim, index, out}. Selection along the channel
// dimension goes to the dedicated channel shader; every other dimension uses
// the generic index_select shader.
void index_select(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = prepack_if_tensor_ref(graph, args[0]);
  const ValueRef idx = prepack_if_tensor_ref(graph, args[2]);
  const ValueRef out = args[3];

  const int64_t dim_idx = get_dim_idx(graph, in, args[1]);
  if (dim_idx == kChannel4D) {
    add_index_select_channel_node(graph, in, idx, out);
    return;
  }
  add_index_select_node(graph, in, dim_idx, idx, out);
}
129+
130+
// Registers index_select as the implementation of aten.index_select.default.
REGISTER_OPERATORS {
131+
VK_REGISTER_OP(aten.index_select.default, index_select);
132+
}
133+
134+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,34 @@ def get_slice_inputs():
416416
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
417417

418418
test_suite.dtypes = ["at::kFloat"]
419-
test_suite.layouts = [
420-
"api::kChannelsPacked",
421-
]
419+
test_suite.layouts = ["api::kChannelsPacked"]
422420
test_suite.data_gen = "make_seq_tensor"
423421
return test_suite
424422

425423

424+
def get_index_select_inputs():
    """Build the op-test suite for ``aten.index_select.default``.

    Every index list below is applied along each of the four dimensions of a
    ``[9, 9, 9, 9]`` input. The suite runs with ``at::kFloat`` and the
    ``api::kChannelsPacked`` layout.
    """
    Test = namedtuple("VkIndexSelectTest", ["self", "dim", "index"])
    Test.__new__.__defaults__ = (None, 0, None)

    index_lists = [
        [0],
        [2],
        [0, 2],
        [3, 1],
        [5, 5],
        [2, 3, 4, 5, 7],
    ]

    # Outer loop over dims, inner loop over index lists; each case gets its
    # own copy of the index list.
    test_cases = [
        Test(self=[9, 9, 9, 9], dim=dim, index=list(index))
        for dim in range(4)
        for index in index_lists
    ]

    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
    test_suite.dtypes = ["at::kFloat"]
    test_suite.layouts = ["api::kChannelsPacked"]
    return test_suite
445+
446+
426447
def get_unsqueeze_inputs():
427448
test_suite = VkTestSuite(
428449
[
@@ -816,6 +837,7 @@ def get_gelu_inputs():
816837
"aten.view_copy.default": get_view_inputs(),
817838
"aten.slice_copy.Tensor": get_slice_inputs(),
818839
"aten.slice.Tensor": get_slice_inputs(),
840+
"aten.index_select.default": get_index_select_inputs(),
819841
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
820842
"aten.clone.default": get_clone_inputs(),
821843
"aten.repeat.default": get_repeat_inputs(),

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,12 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
154154
ret_str = f"{cpp_type} {arg.name} = "
155155

156156
if cpp_type == AT_TENSOR:
157-
ret_str += f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);"
157+
if arg.name == "index":
158+
ret_str += f"make_index_tensor({init_list_str(data)});"
159+
else:
160+
ret_str += (
161+
f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);"
162+
)
158163
elif cpp_type == OPT_AT_TENSOR:
159164
if str(data) == "None":
160165
ret_str += "std::nullopt;"
@@ -267,7 +272,7 @@ def generate_suite_cpp(self) -> str:
267272
268273
at::Tensor make_seq_tensor(
269274
std::vector<int64_t> sizes,
270-
at::ScalarType dtype = at::kFloat) {{
275+
at::ScalarType dtype = at::kFloat) {{
271276
int64_t n = 1;
272277
for (auto size: sizes) {{
273278
n *= size;
@@ -283,6 +288,16 @@ def generate_suite_cpp(self) -> str:
283288
return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
284289
}}
285290
291+
292+
// Builds a 1-D at::kInt tensor holding the given indices.
at::Tensor make_index_tensor(std::vector<int64_t> indices) {{
  const int64_t size = static_cast<int64_t>(indices.size());

  // The vector stores 64-bit elements, so the blob has to be described as
  // at::kLong; describing it as at::kInt would reinterpret every 64-bit value
  // as two 32-bit ones and silently change the indices under test. The
  // conversion to at::kInt happens afterwards.
  // from_blob doesn't take ownership of data. Hence must create a copy as
  // "indices" will go out of scope.
  return at::from_blob(indices.data(), {{size}}, at::kLong)
      .toType(at::kInt)
      .detach()
      .clone();
}}
300+
286301
{test_suites_cpp}
287302
"""
288303

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,3 +1272,39 @@ def forward(self, x):
12721272
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
12731273
custom_pass=[MeanToSumDiv()],
12741274
)
1275+
1276+
def test_vulkan_backend_index_select_int(self):
    """Lower ``torch.index_select`` along dim 1 of an int32 input and check the output."""

    class IndexSelectModule(torch.nn.Module):
        def __init__(self, dim, indices):
            super().__init__()
            self.dim = dim
            self.index = torch.tensor(indices, dtype=torch.int32)

        def forward(self, x):
            return torch.index_select(x, self.dim, self.index)

    module = IndexSelectModule(dim=1, indices=[2, 3, 5, 6, 7])
    # Values 0..95 as int32, shaped (2, 8, 2, 3).
    sample_inputs = (torch.arange(96, dtype=torch.int32).reshape(2, 8, 2, 3),)

    self.lower_module_and_test_output(
        module,
        sample_inputs,
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )
1293+
1294+
def test_vulkan_backend_index_select(self):
    """Lower ``torch.index_select`` along dim 0 of a float input and check the output."""

    class IndexSelectModule(torch.nn.Module):
        def __init__(self, dim, indices):
            super().__init__()
            self.dim = dim
            self.index = torch.tensor(indices, dtype=torch.int32)

        def forward(self, x):
            return torch.index_select(x, self.dim, self.index)

    module = IndexSelectModule(dim=0, indices=[1, 3, 5, 7, 8, 9, 10, 11, 2, 3])
    # Values 0..143 as float32, shaped (12, 1, 3, 4).
    sample_inputs = (torch.arange(144, dtype=torch.float32).reshape(12, 1, 3, 4),)

    self.lower_module_and_test_output(
        module,
        sample_inputs,
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )

0 commit comments

Comments
 (0)