
Commit 7b36eff

[ET-VK][Ops] aten.index_select
## The Operator

`nn.Module` invocations of [`torch.index_select`](https://pytorch.org/docs/stable/generated/torch.index_select.html) are compiled to `aten.index_select.default` in the Edge Dialect, which carries the following signature.

```
- func: index_select(Tensor self, int dim, Tensor index) -> Tensor
```

## Implementation

This is a C-packing-only implementation. It is very similar to `aten.slice` (#3171):

```
- func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
```

It features a similar split between a shader for N, H, W and a shader for C, because copying from the C dimension is harder due to C-packing.

Both `index_select` and `slice` copy specific indices along one dimension; the difference is how those indices are specified.

- `slice` specifies them with three scalars, e.g. `start=1`/`end=5`/`step=2` for indices `1, 3`.
- `index_select` lists the exact indices in a tensor, e.g. `index=torch.tensor([1, 3])`.

Hence, `slice` computes the input position from `offset=1` and `step=2`, while `index_select` reads the index tensor to compute the input position.

Differential Revision: [D57745489](https://our.internmc.facebook.com/intern/diff/D57745489/)

[ghstack-poisoned]
1 parent 1343224 commit 7b36eff
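
To make the contrast between the two indexing styles concrete, here is a small PyTorch snippet (an illustration only, not part of this commit): the same rows can be taken either with slice-style `start`/`end`/`step` arguments or with an explicit index tensor.

```python
import torch

x = torch.arange(24).reshape(4, 6)

# slice: the selected indices are implied by start/end/step.
sliced = x[1:5:2, :]  # rows 1 and 3

# index_select: the selected indices are listed explicitly in a tensor.
selected = torch.index_select(x, 0, torch.tensor([1, 3]))

assert torch.equal(sliced, selected)
```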

9 files changed (+336, -5 lines)

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -99,6 +99,7 @@ def __contains__(self, op):
 ]

 INDEXING_OPS = [
+    exir_ops.edge.aten.index_select.default,
     exir_ops.edge.aten.select_copy.int,
     exir_ops.edge.aten.slice_copy.Tensor,
 ]
```
Lines changed: 44 additions & 0 deletions (new file)

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
${layout_declare_ubo(3, "ivec4", "sizes")}
${layout_declare_ubo(4, "int", "gpu_dim", "int", "stride")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  if (pos_out_of_bounds(out_pos, sizes, packed_dim)) {
    return;
  }

  const int out_idx = out_pos[gpu_dim] / stride;
  const int within_stride = out_pos[gpu_dim] % stride;
  const int in_idx = texelFetch(t_idx, ivec3(out_idx, 0, 0), 0).x;

  ivec3 in_pos = out_pos;
  in_pos[gpu_dim] = in_idx * stride + within_stride;

  imageStore(t_out, out_pos, texelFetch(t_in, in_pos, 0));
}
```
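
The position arithmetic of the N/H/W shader above can be paraphrased in Python. This is an illustrative sketch only; `input_pos_for` is a made-up helper, with `out_pos`, `gpu_dim`, `stride`, and `index` standing in for the shader's texture position, UBO fields, and index tensor.

```python
def input_pos_for(out_pos, gpu_dim, stride, index):
    """Mirror the shader's mapping from output position to input position.

    out_pos : output texel position (x, y, z) as a list of 3 ints.
    gpu_dim : texture axis the selected dim maps to (0, 1, or 2).
    stride  : 1 for width/height; ceil(C / 4) when selecting along batch,
              since each batch occupies that many texels along z.
    index   : contents of the index tensor (a plain Python list here).
    """
    out_idx = out_pos[gpu_dim] // stride        # which selected index this texel belongs to
    within_stride = out_pos[gpu_dim] % stride   # offset inside one batch's block of texels
    in_idx = index[out_idx]                     # look up the source index

    in_pos = list(out_pos)
    in_pos[gpu_dim] = in_idx * stride + within_stride
    return in_pos


# Selecting along width (gpu_dim=0, stride=1) with index=[1, 3]:
assert input_pos_for([0, 5, 2], gpu_dim=0, stride=1, index=[1, 3]) == [1, 5, 2]
assert input_pos_for([1, 5, 2], gpu_dim=0, stride=1, index=[1, 3]) == [3, 5, 2]
```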
Lines changed: 11 additions & 0 deletions (new file)

```yaml
index_select:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: index_select
```
Lines changed: 55 additions & 0 deletions (new file)

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)}
${layout_declare_ubo(3, "ivec4", "out_sizes")}
${layout_declare_ubo(4, "ivec4", "in_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  if (pos_out_of_bounds(out_pos, out_sizes, packed_dim)) {
    return;
  }

  const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim);
  const ivec4 buffer_ixs = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim);

  vec4 out_texel;
  for (int i = 0; i < 4; ++i) {
    const ivec4 out_idx = from_nchw_buffer_i(buffer_ixs[i], out_sizes);
    int out_channel = out_idx.z;
    int in_channel = texelFetch(t_idx, ivec3(out_channel, 0, 0), 0).x;

    ivec4 in_idx = out_idx;
    in_idx.z = in_channel;

    ivec4 in_elem_pos = to_texture_elem_pos(in_idx, in_sizes, packed_dim);

    vec4 in_texel = texelFetch(t_in, in_elem_pos.xyz, 0);

    out_texel[i] = in_texel[in_elem_pos.w];
  }
  imageStore(t_out, out_pos, out_texel);
}
```
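
The channel shader gathers element by element because, in a channels-packed texture, remapping a channel changes both which texel an element lives in and its component within that texel. A tiny sketch of that packing rule (illustrative only; `channel_to_texel` is a made-up helper and batch folding along z is ignored here):

```python
def channel_to_texel(c):
    """In channels-packed layout, channel c lives in texel z = c // 4,
    at component c % 4 of that texel (single batch assumed)."""
    return c // 4, c % 4


# Selecting channels [0, 5, 2, 7] for one output texel: the four output
# components come from different input texels and different components,
# so the shader reconstructs the output texel one element at a time.
print([channel_to_texel(c) for c in [0, 5, 2, 7]])
# [(0, 0), (1, 1), (0, 2), (1, 3)]
```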
Lines changed: 10 additions & 0 deletions (new file)

```yaml
index_select_channel:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: index_select_channel
```
Lines changed: 137 additions & 0 deletions (new file)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void check_index_select_args(
    const vTensor& in,
    const vTensor& idx,
    const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(idx, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}

void add_index_select_channel_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef idx,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_idx = graph.get_tensor(idx);
  vTensorPtr t_out = graph.get_tensor(out);

  check_index_select_args(*t_in, *t_idx, *t_out);

  std::string kernel_name = "index_select_channel";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE},
       {{in, idx}, api::MemoryAccessType::READ}},
      {t_out->sizes_ubo(), t_in->sizes_ubo()}));
}

struct IndexSelectParams final {
  int32_t gpu_dim;
  int32_t stride;
};

IndexSelectParams create_index_select_params(
    const int64_t dim_idx,
    const vTensor& in) {
  if (dim_idx == kWidth4D) {
    return {0, 1};
  } else if (dim_idx == kHeight4D) {
    return {1, 1};
  } else if (dim_idx == kBatch4D) {
    int64_t n_channels = dim_at(in.sizes(), kChannel4D);
    int64_t stride = api::utils::div_up_4(n_channels);
    return {2, static_cast<int32_t>(stride)};
  } else {
    VK_THROW("Unexpected dim_idx!");
  }
}

void add_index_select_node(
    ComputeGraph& graph,
    ValueRef in,
    const int64_t dim_idx,
    ValueRef idx,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_idx = graph.get_tensor(idx);
  vTensorPtr t_out = graph.get_tensor(out);

  check_index_select_args(*t_in, *t_idx, *t_out);

  IndexSelectParams params = create_index_select_params(dim_idx, *t_in);

  std::string kernel_name = "index_select";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE},
       {{in, idx}, api::MemoryAccessType::READ}},
      {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
}

int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) {
  vTensorPtr t_in = graph.get_tensor(in);
  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  dim = normalize(dim, t_in->dim());
  return normalize_to_dim_index(*t_in, dim);
}

void index_select(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef in = prepack_if_tensor_ref(graph, args[0]);
  ValueRef dim_ref = args[1];
  ValueRef idx = prepack_if_tensor_ref(graph, args[2]);
  ValueRef out = args[3];

  const int64_t dim_idx = get_dim_idx(graph, in, dim_ref);
  if (dim_idx == kChannel4D) {
    add_index_select_channel_node(graph, in, idx, out);
  } else {
    add_index_select_node(graph, in, dim_idx, idx, out);
  }
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.index_select.default, index_select);
}

} // namespace vkcompute
```
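
The `(gpu_dim, stride)` pair chosen by `create_index_select_params` can be restated informally in Python (a sketch under the same channels-packed assumptions; the string dim names stand in for the `kWidth4D`/`kHeight4D`/`kBatch4D` constants):

```python
import math

def index_select_params(dim_name, n_channels):
    """Map the selected tensor dim to a texture axis and a stride.

    Width and height map directly to texture x/y with stride 1. Batches are
    stacked along texture z in blocks of ceil(C / 4) texels, so selecting
    along batch uses axis 2 with that block size as the stride. Selecting
    along channels is dispatched to the separate channel shader instead.
    """
    if dim_name == "width":
        return (0, 1)
    if dim_name == "height":
        return (1, 1)
    if dim_name == "batch":
        return (2, math.ceil(n_channels / 4))
    raise ValueError("unexpected dim")


# A tensor with 8 channels occupies 2 texels along z per batch.
assert index_select_params("batch", 8) == (2, 2)
```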

backends/vulkan/test/op_tests/cases.py

Lines changed: 25 additions & 3 deletions

```diff
@@ -395,13 +395,34 @@ def get_slice_inputs():
     test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

     test_suite.dtypes = ["at::kFloat"]
-    test_suite.layouts = [
-        "api::kChannelsPacked",
-    ]
+    test_suite.layouts = ["api::kChannelsPacked"]
     test_suite.data_gen = "make_seq_tensor"
     return test_suite


+def get_index_select_inputs():
+    Test = namedtuple("VkIndexSelectTest", ["self", "dim", "index"])
+    Test.__new__.__defaults__ = (None, 0, None)
+
+    test_cases = []
+
+    for i in range(4):
+        test_cases += [
+            Test(self=[9, 9, 9, 9], dim=i, index=[0]),
+            Test(self=[9, 9, 9, 9], dim=i, index=[2]),
+            Test(self=[9, 9, 9, 9], dim=i, index=[0, 2]),
+            Test(self=[9, 9, 9, 9], dim=i, index=[3, 1]),
+            Test(self=[9, 9, 9, 9], dim=i, index=[5, 5]),
+            Test(self=[9, 9, 9, 9], dim=i, index=[2, 3, 4, 5, 7, 10]),
+        ]
+
+    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
+
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = ["api::kChannelsPacked"]
+    return test_suite
+
+
 def get_unsqueeze_inputs():
     test_suite = VkTestSuite(
         [
@@ -795,6 +816,7 @@ def get_gelu_inputs():
     "aten.view_copy.default": get_view_inputs(),
     "aten.slice_copy.Tensor": get_slice_inputs(),
     "aten.slice.Tensor": get_slice_inputs(),
+    "aten.index_select.default": get_index_select_inputs(),
     "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
     "aten.clone.default": get_clone_inputs(),
     "aten.repeat.default": get_repeat_inputs(),
```

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -154,7 +154,12 @@ def create_input_data(self, arg: Argument, data: Any) -> str:  # noqa: C901
        ret_str = f"{cpp_type} {arg.name} = "

        if cpp_type == AT_TENSOR:
-            ret_str += f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);"
+            if arg.name == "index":
+                ret_str += f"make_index_tensor({init_list_str(data)});"
+            else:
+                ret_str += (
+                    f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);"
+                )
        elif cpp_type == OPT_AT_TENSOR:
            if str(data) == "None":
                ret_str += "std::nullopt;"
@@ -267,7 +272,7 @@ def generate_suite_cpp(self) -> str:

 at::Tensor make_seq_tensor(
   std::vector<int64_t> sizes,
-  at::ScalarType dtype = at::kFloat) {{
+  at::ScalarType dtype = at::kFloat) {{
   int64_t n = 1;
   for (auto size: sizes) {{
     n *= size;
@@ -283,6 +288,16 @@ def generate_suite_cpp(self) -> str:
   return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
 }}

+
+at::Tensor make_index_tensor(std::vector<int64_t> indices) {{
+  int64_t size = static_cast<int64_t>(indices.size());
+  at::ScalarType dtype = at::kInt;
+
+  // from_blob doesn't take ownership of data. Hence must create a copy as
+  // "values" will go out of scope.
+  return at::from_blob(indices.data(), {{size}}, dtype).detach().clone();
+}}
+
 {test_suites_cpp}
 """
```

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 36 additions & 0 deletions

```diff
@@ -1255,3 +1255,39 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
             custom_pass=[MeanToSumDiv()],
         )
+
+    def test_vulkan_backend_index_select_channel(self):
+        class IndexSelectModule(torch.nn.Module):
+            def __init__(self, dim, indices):
+                super().__init__()
+                self.dim = dim
+                self.index = torch.tensor(indices, dtype=torch.int32)
+
+            def forward(self, x):
+                return torch.index_select(x, self.dim, self.index)
+
+        sample_inputs = (torch.arange(96).reshape(2, 8, 2, 3).float(),)
+
+        self.lower_module_and_test_output(
+            IndexSelectModule(dim=1, indices=[2, 3, 5, 6, 7]),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_index_select(self):
+        class IndexSelectModule(torch.nn.Module):
+            def __init__(self, dim, indices):
+                super().__init__()
+                self.dim = dim
+                self.index = torch.tensor(indices, dtype=torch.int32)
+
+            def forward(self, x):
+                return torch.index_select(x, self.dim, self.index)
+
+        sample_inputs = (torch.arange(144).reshape(12, 1, 3, 4).float(),)
+
+        self.lower_module_and_test_output(
+            IndexSelectModule(dim=0, indices=[1, 3, 5, 7, 8, 9, 10, 11, 2, 3]),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
```
