
Commit bf9a946

[ET-VK][7/n] Slice, with lots of codegen improvements
Pull Request resolved: #3171

1. Add the slice operation. Instead of using copy as in LI, we implement a simple shader with offsets.
2. Improvements in codegen:
   - Add support for optional variables.
   - Improve indentation of the generated code, for better readability.
   - Allow the user to specify tensor value generation; sequential values can be generated for easier debugging of index operations.
   - Improve the sample code's test-case specification, particularly for long and optional values.

ghstack-source-id: 223254861
Differential Revision: [D56295985](https://our.internmc.facebook.com/intern/diff/D56295985/)
1 parent fa433cb commit bf9a946
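For reference, the slice mapping that the new shaders implement is out[i] = in[start + i * step] along the sliced dimension. A minimal scalar sketch of that semantics (standalone C++; the names and values are illustrative, not part of this diff):

#include <cstdio>
#include <vector>

int main() {
  // slice(in, start=1, end=6, step=2) should select {11, 13, 15}
  const std::vector<int> in = {10, 11, 12, 13, 14, 15};
  const int start = 1, end = 6, step = 2;

  // Output length along the sliced dimension, using integer division.
  const int out_len = 1 + (end - start - 1) / step;

  for (int i = 0; i < out_len; ++i) {
    std::printf("%d ", in[start + i * step]); // prints: 11 13 15
  }
  return 0;
}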

11 files changed: +454 −18

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 11 additions & 0 deletions
@@ -10,6 +10,8 @@
 
 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
 
+#include <optional>
+
 #include <executorch/backends/vulkan/runtime/api/api.h>
 
 #include <executorch/backends/vulkan/runtime/graph/GraphConfig.h>
@@ -184,6 +186,15 @@ class ComputeGraph final {
     VK_THROW("Cannot extract scalar from Value with type ", value.type());
   }
 
+  template <typename T>
+  std::optional<T> extract_optional_scalar(const ValueRef idx) {
+    if (val_is_none(idx)) {
+      return ::std::nullopt;
+    } else {
+      return extract_scalar<T>(idx);
+    }
+  }
+
   inline std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes() {
     return prepack_nodes_;
   }
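The new helper returns std::nullopt for None values so that callers can fall back to defaults via value_or, as the slice op below does for its optional start/end arguments. A standalone sketch of that pattern, with the ValueRef plumbing elided (extract_optional here is a stand-in, not the graph API):

#include <cstdio>
#include <optional>

// Stand-in for extract_optional_scalar: None -> nullopt, else the scalar.
std::optional<long> extract_optional(bool is_none, long value) {
  if (is_none) {
    return std::nullopt;
  }
  return value;
}

int main() {
  const long dim_size = 8;
  const long start = extract_optional(true, 0).value_or(0);       // None -> 0
  const long end = extract_optional(false, 5).value_or(dim_size); // 5
  std::printf("start=%ld end=%ld\n", start, end); // start=0 end=5
  return 0;
}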

backends/vulkan/runtime/graph/Logging.h

Lines changed: 11 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 #include <executorch/backends/vulkan/runtime/api/Utils.h>
 
+#include <optional>
 #include <ostream>
 #include <vector>
 
@@ -33,4 +34,14 @@ inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec4& v) {
   return api::utils::operator<<(os, v);
 }
 
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const std::optional<T>& opt) {
+  os << "[";
+  if (opt) {
+    os << opt.value();
+  }
+  os << "]";
+  return os;
+}
+
 } // namespace vkcompute
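A quick usage sketch of the printing convention added above ("[v]" when engaged, "[]" when empty); the operator is redefined here only to keep the example self-contained:

#include <iostream>
#include <optional>

// Same convention as the operator<< added to Logging.h above.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::optional<T>& opt) {
  os << "[";
  if (opt) {
    os << opt.value();
  }
  os << "]";
  return os;
}

int main() {
  std::cout << std::optional<int>{3} << " " << std::optional<int>{} << "\n";
  // prints: [3] []
  return 0;
}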

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 3 additions & 3 deletions
@@ -8,15 +8,15 @@
 
 #define divup4(x) ((x + 3) / 4)
 
-// Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape
-// Output: buffer_idx in the continuous nchw-buffer.
+// Input: idx is an ivec4 user-level (w, h, c, n) coordinate, sizes is the
+// tensor shape. Output: buffer_idx in the continuous nchw-buffer.
 #define to_buffer_i(idx, sizes) \
   (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \
    idx.w * sizes.z * sizes.y * sizes.x)
 
 // Inverse of to_buffer_i
 // Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape
-// Output: ivec4 user-level coorindate
+// Output: ivec4 user-level (w, h, c, n) coordinate
 #define from_buffer_i(buf_i, sizes) \
   ivec4( \
     buf_i % sizes.x, \
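The hunk ends before the remaining components of from_buffer_i, but the inverse follows the usual mod/div pattern. A CPU-side check of the macro arithmetic, assuming sizes = (W, H, C, N) and idx = (w, h, c, n) per the convention above:

#include <array>
#include <cstdio>

int main() {
  const std::array<int, 4> sizes = {4, 3, 2, 2}; // W, H, C, N
  const std::array<int, 4> idx = {1, 2, 1, 1};   // w, h, c, n

  // to_buffer_i: w + h*W + c*H*W + n*C*H*W
  const int buf_i = idx[0] + idx[1] * sizes[0] + idx[2] * sizes[1] * sizes[0] +
      idx[3] * sizes[2] * sizes[1] * sizes[0];

  // from_buffer_i: invert the flattening with mod/div
  const std::array<int, 4> back = {
      buf_i % sizes[0],
      (buf_i / sizes[0]) % sizes[1],
      (buf_i / (sizes[0] * sizes[1])) % sizes[2],
      buf_i / (sizes[0] * sizes[1] * sizes[2]),
  };

  std::printf("buf_i=%d back=(%d,%d,%d,%d)\n", buf_i, back[0], back[1],
              back[2], back[3]); // buf_i=45 back=(1,2,1,1)
  return 0;
}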
backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
+  uvec4 data;
+}
+out_sizes;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict SliceArg {
+  int dim;
+  int offset;
+  int step;
+  // Used when dim=batch. Stride is the # of planes for each batch value.
+  int stride;
+}
+slice_arg;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+
+  const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data);
+
+  if (any(greaterThanEqual(idx, out_sizes.data))) {
+    return;
+  }
+
+  ivec3 in_pos = out_pos;
+
+  int index = out_pos[slice_arg.dim] / slice_arg.stride;
+  int within_stride = out_pos[slice_arg.dim] % slice_arg.stride;
+
+  in_pos[slice_arg.dim] = slice_arg.offset * slice_arg.stride +
+      index * slice_arg.step * slice_arg.stride + within_stride;
+
+  imageStore(image_out, out_pos, texelFetch(image_in, in_pos, 0));
+}
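When slicing the batch dimension, each batch value occupies `stride` z-planes (ceil(C / 4) due to channel packing), so the shader splits the output z-coordinate into a batch index and a plane-within-batch before applying the offset/step. A CPU-side sketch of the same arithmetic, with made-up sizes:

#include <cstdio>

int main() {
  const int offset = 1;  // slice start along the batch dim
  const int step = 2;    // slice step
  const int stride = 3;  // ceil(C / 4): z-planes per batch value

  // Map each output z-plane to the input z-plane it is copied from,
  // mirroring: in_z = offset*stride + index*step*stride + within_stride.
  for (int out_z = 0; out_z < 6; ++out_z) {
    const int index = out_z / stride;          // output batch index
    const int within_stride = out_z % stride;  // plane within the batch
    const int in_z =
        offset * stride + index * step * stride + within_stride;
    std::printf("out_z=%d -> in_z=%d\n", out_z, in_z);
    // out_z 0..2 read input batch 1 (planes 3..5),
    // out_z 3..5 read input batch 3 (planes 9..11).
  }
  return 0;
}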
backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.yaml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
1+
slice_batch_height_width:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: slice_batch_height_width
backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+#define to_tensor_idx to_tensor_idx_${PACKING}
+#define to_texture_pos_elem to_texture_pos_elem_${PACKING}
+#define get_packed_stride get_packed_stride_${PACKING}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
+  uvec4 data;
+}
+out_sizes;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes {
+  uvec4 out_cpu_sizes;
+};
+
+layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes {
+  uvec4 in_gpu_sizes;
+};
+
+layout(set = 0, binding = 5) uniform PRECISION restrict SliceArg {
+  int offset;
+  int step;
+}
+slice_arg;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+
+  const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data);
+
+  if (any(greaterThanEqual(idx, out_sizes.data))) {
+    return;
+  }
+
+  // We map the output pos using the buffer index. For each index in the
+  // texel, we calculate the source whcn-coordinate amended with the
+  // offset-adjusted channel value. Then we calculate the actual texture
+  // position from the whcn-coordinate.
+  const uint base_index = to_buffer_i(idx, out_cpu_sizes);
+  uvec4 buf_indices =
+      base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes);
+
+  vec4 outex;
+  for (int i = 0; i < 4; i++) {
+    ivec4 user_coor = from_buffer_i(buf_indices[i], out_cpu_sizes);
+
+    int in_channel = user_coor.z;
+
+    ivec4 in_user_coor = user_coor;
+    in_user_coor.z = slice_arg.offset + in_channel * slice_arg.step;
+
+    ivec4 in_pow_elem = to_texture_pos_elem_C_packed(
+        in_user_coor,
+        in_gpu_sizes);
+
+    vec4 v = texelFetch(image_in, in_pow_elem.xyz, 0);
+
+    outex[i] = v[in_pow_elem.w];
+  }
+  imageStore(image_out, out_pos, outex);
+}
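Because channels are packed four per texel, neighboring output channels may come from different input texels, which is why the loop above resolves a per-element source coordinate instead of copying whole texels. A CPU-side sketch of the channel mapping with illustrative values:

#include <cstdio>

int main() {
  const int offset = 2;       // slice start along the channel dim
  const int step = 3;         // slice step
  const int out_channels = 4; // one output texel's worth of channels

  for (int c = 0; c < out_channels; ++c) {
    const int in_c = offset + c * step; // mirrors in_user_coor.z above
    // Input texel plane and lane within the texel (C_packed layout).
    std::printf("out channel %d -> in channel %d (texel z=%d, lane %d)\n",
                c, in_c, in_c / 4, in_c % 4);
  }
  return 0;
}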
backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
1+
slice_channel:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: float
8+
PACKING:
9+
- VALUE: C_packed
10+
shader_variants:
11+
- NAME: slice_channel
backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/Logging.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+void add_slice_tensor_out_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef step_ref,
+    ValueRef out) {
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
+  VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
+
+  // Need to normalize the dim
+  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+
+  VK_CHECK_COND(
+      -t_in->dim() <= dim && dim < t_in->dim(),
+      "dim must be in range of [-self.dim(), self.dim()), but current dim's value is ",
+      dim,
+      " and self.dim() = ",
+      t_in->dim());
+
+  dim = normalize(dim, t_in->dim());
+
+  // Create a dim value as if the underlying tensor were 4-dimensional.
+  int64_t nchw_dim = dim + (4 - t_in->dim());
+
+  std::optional<int64_t> opt_start =
+      graph.extract_optional_scalar<int64_t>(opt_start_ref);
+  std::optional<int64_t> opt_end =
+      graph.extract_optional_scalar<int64_t>(opt_end_ref);
+  int64_t step = graph.extract_scalar<int64_t>(step_ref);
+
+  const auto in_sizes = t_in->sizes();
+  const auto out_sizes = t_out->sizes();
+
+  int64_t start = opt_start.value_or(0);
+  int64_t end = opt_end.value_or(in_sizes[dim]);
+
+  VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
+  VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
+
+  if (nchw_dim == 1) {
+    // slice by channel
+    std::string kernel_name = "slice_channel";
+    kernel_name.reserve(kShaderNameReserve);
+    add_dtype_suffix(kernel_name, *t_out);
+    add_memory_layout_suffix(kernel_name, *t_out);
+
+    api::utils::uvec3 global_size = t_out->extents();
+    api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+    const struct Block final {
+      int offset;
+      int step;
+    } params{
+        static_cast<int32_t>(start),
+        static_cast<int32_t>(step),
+    };
+
+    graph.execute_nodes().emplace_back(new ExecuteNode(
+        graph,
+        VK_KERNEL_FROM_STR(kernel_name),
+        global_size,
+        local_size,
+        {{out, api::MemoryAccessType::WRITE},
+         {in, api::MemoryAccessType::READ}},
+        {t_out->gpu_sizes_ubo(),
+         t_out->cpu_sizes_ubo(),
+         t_in->gpu_sizes_ubo(),
+         graph.create_params_buffer(params)}));
+
+  } else {
+    // GPU coordinates are in x, y, z
+    int64_t gpu_dim = -1;
+    int64_t stride = 1;
+    if (nchw_dim == 3) {
+      gpu_dim = 0; // width: x dimension in gpu
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+    } else if (nchw_dim == 2) {
+      gpu_dim = 1; // height: y dimension
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+    } else if (nchw_dim == 0) {
+      gpu_dim = 2; // batch: z dimension
+
+      // Due to channel packing, each batch value spans `stride` planes
+      int64_t n_channels = dim_at<Dim4D::Channel>(in_sizes);
+      stride = api::utils::div_up<int64_t>(n_channels, 4ll);
+    } else {
+      VK_THROW("Unexpected nchw_dim!");
+    }
+
+    std::string kernel_name = "slice_batch_height_width";
+    kernel_name.reserve(kShaderNameReserve);
+    add_dtype_suffix(kernel_name, *t_out);
+
+    api::utils::uvec3 global_size = t_out->extents();
+    api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+    const struct Block final {
+      int dim;
+      int offset;
+      int step;
+      int stride;
+    } params{
+        static_cast<int32_t>(gpu_dim),
+        static_cast<int32_t>(start),
+        static_cast<int32_t>(step),
+        static_cast<int32_t>(stride),
+    };
+
+    graph.execute_nodes().emplace_back(new ExecuteNode(
+        graph,
+        VK_KERNEL_FROM_STR(kernel_name),
+        global_size,
+        local_size,
+        {{out, api::MemoryAccessType::WRITE},
+         {in, api::MemoryAccessType::READ}},
+        {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)}));
+  }
+}
+
+void slice_tensor_out(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_slice_tensor_out_node(
+      graph,
+      args[0],
+      args[1], // dim
+      args[2], // optional start
+      args[3], // optional end
+      args[4], // step
+      args[5]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_out);
+}
+
+} // namespace vkcompute
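A standalone check of the dim bookkeeping above: negative dims are normalized into [0, ndim), then shifted into the 4-dimensional NCHW frame. This sketch assumes normalize() wraps negative dims in the usual PyTorch way; it is not the graph API:

#include <cassert>

int main() {
  const long ndim = 3; // e.g. a CHW tensor
  long dim = -2;       // user-specified dim

  // normalize(dim, ndim): wrap negatives into [0, ndim)
  dim = (dim % ndim + ndim) % ndim;
  assert(dim == 1); // H

  // Shift into the 4-d NCHW frame: CHW dim 1 (H) -> nchw_dim 2,
  // which selects the slice_batch_height_width (height) branch above.
  const long nchw_dim = dim + (4 - ndim);
  assert(nchw_dim == 2);
  return 0;
}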
