Skip to content

Commit ae250c0

Browse files
yipjustinfacebook-github-bot
authored andcommitted
copy_channel_offsets node (#3351)
Summary: Pull Request resolved: #3351 1. Add a node `copy_channel_offsets` specifically for copying along the channel dimension, it needs extra attention at the boundaries due to channel packing. 1.1. `copy_channel_offsets` will be useful for `aten.cat` and `aten.split`. 2. Create `etvk.*` operators to facilitate testing. Add test case for both `copy_offset` and `copy_channel_offset`. ghstack-source-id: 224214136 Reviewed By: jorgep31415 Differential Revision: D56554426 fbshipit-source-id: c4190480a15359ec38af34b043812696c2ccabcb
1 parent 9314595 commit ae250c0

File tree

7 files changed

+575
-30
lines changed

7 files changed

+575
-30
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
20+
layout(set = 0, binding = 1) uniform PRECISION sampler3D existing_out;
21+
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in;
22+
23+
layout(set = 0, binding = 3) uniform PRECISION restrict CopyArgs {
24+
ivec4 out_sizes;
25+
ivec4 in_sizes;
26+
// Analogous to the range variable in copy. It defines the # of channels being
27+
// copied.
28+
int channel_range;
29+
int src_channel_offset;
30+
int dst_channel_offset;
31+
int unused;
32+
// Operates on (x, y, z) extents.
33+
ivec3 range;
34+
int unused1;
35+
ivec3 dst_offset;
36+
int unused2;
37+
};
38+
39+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
40+
41+
layout(constant_id = 3) const int packed_dim = C_DIM;
42+
43+
// Writes one output texel per invocation. The texel currently stored in the
// destination is read first, then each of its four components is replaced
// only when the corresponding channel falls inside the copied channel range,
// so that components outside the range keep their existing values.
void main() {
  // Note: Unlike other shaders, the range is often not equal to the destination
  // texture extent.
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  if (any(greaterThanEqual(pos, range))) {
    return;
  }

  // Shift the invocation id into the destination texture region.
  const ivec3 out_pos = pos + dst_offset;
  const ivec4 out_whcn = to_tensor_idx(out_pos, out_sizes, packed_dim);

  // Start from the existing values to make sure the boundary values stay.
  VEC4_T texel = VEC4_T(texelFetch(existing_out, out_pos, 0));

  for (int lane = 0; lane < 4; lane++) {
    // Channel index of this component relative to the start of the copied
    // range in the destination.
    const int rel_channel = out_whcn.z - dst_channel_offset + lane;

    // Handle the partial update at the beginning and end of the channel range
    // in an existing tensor: components whose relative channel index is below
    // zero or not less than channel_range are left untouched.
    if ((rel_channel >= 0) && (rel_channel < channel_range)) {
      ivec4 in_whcn = out_whcn;
      // Readjust for the source offset.
      in_whcn.z = rel_channel + src_channel_offset;

      const ivec4 in_elem_pos =
          to_texture_elem_pos(in_whcn, in_sizes, packed_dim);
      // xyz selects the source texel, w selects the component inside it.
      texel[lane] =
          VEC4_T(texelFetch(image_in, in_elem_pos.xyz, 0))[in_elem_pos.w];
    }
  }

  imageStore(image_out, out_pos, texel);
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
copy_channel_offset:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,12 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13-
#define VEC4_T ${texel_type(DTYPE)}
14-
1513
layout(std430) buffer;
1614

17-
#include "indexing_utils.h"
18-
1915
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
2016
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
2117

22-
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
23-
ivec3 out_limits;
24-
};
25-
26-
layout(set = 0, binding = 3) uniform PRECISION restrict InLimits {
27-
ivec3 in_limits;
28-
};
29-
30-
31-
32-
layout(set = 0, binding = 4) uniform PRECISION restrict CopyArgs {
18+
layout(set = 0, binding = 2) uniform PRECISION restrict CopyArgs {
3319
ivec3 range;
3420
int unused0;
3521
ivec3 src_offset;

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 169 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,39 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1415

1516
namespace vkcompute {
1617

18+
using api::utils::ivec3;
19+
using api::utils::uvec3;
20+
1721
void add_copy_offset_node(
1822
ComputeGraph& graph,
1923
const ValueRef in,
20-
const api::utils::ivec3& range,
21-
const api::utils::ivec3& src_offset,
22-
const api::utils::ivec3& dst_offset,
24+
const ivec3& range,
25+
const ivec3& src_offset,
26+
const ivec3& dst_offset,
2327
const ValueRef out) {
2428
vTensorPtr t_in = graph.get_tensor(in);
2529
vTensorPtr t_out = graph.get_tensor(out);
2630

27-
VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
28-
VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
29-
3031
std::string kernel_name = "copy_offset";
3132
kernel_name.reserve(kShaderNameReserve);
3233
add_dtype_suffix(kernel_name, *t_out);
3334

34-
api::utils::uvec3 global_size = api::utils::make_uvec3(range);
35-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
35+
uvec3 global_size = api::utils::make_uvec3(range);
36+
uvec3 local_size = adaptive_work_group_size(global_size);
3637

3738
const struct Block final {
38-
api::utils::ivec3 range;
39+
ivec3 range;
3940
int32_t unused0;
40-
api::utils::ivec3 src_offset;
41+
ivec3 src_offset;
4142
int32_t unused1;
42-
api::utils::ivec3 dst_offset;
43+
ivec3 dst_offset;
4344
int32_t unused2;
4445
} offset_params{
4546
range,
@@ -58,13 +59,166 @@ void add_copy_offset_node(
5859
global_size,
5960
local_size,
6061
// Inputs and Outputs
61-
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
62+
{
63+
{out, api::MemoryAccessType::WRITE},
64+
{in, api::MemoryAccessType::READ},
65+
},
6266
// Parameter buffers
63-
{t_out->texture_limits_ubo(),
64-
t_in->texture_limits_ubo(),
65-
graph.create_params_buffer(offset_params)},
67+
{graph.create_params_buffer(offset_params)},
6668
// Specialization Constants
6769
{}));
6870
}
6971

72+
// Schedules the `copy_channel_offset` shader to copy `channel_range` channels
// from `in`, starting at `src_channel_offset`, into `out`, starting at
// `dst_channel_offset`. Both tensors must use the channels-packed memory
// layout and have at least 3 dimensions.
//
// One ExecuteNode is appended to the graph per batch of `in`. Each node covers
// the full width and height of `in` and only the destination z-slices that
// contain at least one channel of the copied range. `out` is bound both as the
// write target and as a read input so the shader can preserve the texel
// components that lie outside the copied range.
//
// Fails a VK_CHECK_COND if a layout, rank, sign, or bounds precondition does
// not hold.
void add_copy_channel_offset_node(
    ComputeGraph& graph,
    const ValueRef in,
    int32_t channel_range,
    int32_t src_channel_offset,
    int32_t dst_channel_offset,
    const ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  // Likely need to prepad these numbers.
  const std::vector<int64_t>& in_sizes = t_in->sizes();
  const std::vector<int64_t>& out_sizes = t_out->sizes();

  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));

  // NOTE: This function should be able to support 1d and 2d tensors when
  // range=1, src_offset=dst_offset=1.
  VK_CHECK_COND(t_in->dim() >= 3, "Src dim should be at least 3");
  VK_CHECK_COND(t_out->dim() >= 3, "Dst dim should be at least 3");

  // Check the signs before the bounds so that a negative argument is reported
  // with its own message instead of tripping one of the bounds checks below.
  VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
  VK_CHECK_COND(
      src_channel_offset >= 0, "Src channel offset must be non-negative");
  VK_CHECK_COND(
      dst_channel_offset >= 0, "Dst channel offset must be non-negative");

  VK_CHECK_COND(
      dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
      "Source channel plus range should be less than or equal to input tensor's channel size");
  VK_CHECK_COND(
      dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
      "Destination channel plus range should be less than or equal to output tensor's channel size");

  std::string kernel_name = "copy_channel_offset";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  const int32_t out_channels = dim_at<Dim4D::Channel>(out_sizes);

  // Mapping the tensor NCHW coordinates into texture XYZ coordinates: four
  // channels share one z-slice, so these are the first and last destination
  // z-slices (within a single batch) touched by the copied range.
  const int32_t dst_first_z = dst_channel_offset / 4;
  const int32_t dst_last_z = (dst_channel_offset + channel_range - 1) / 4;

  // Number of z-slices each batch occupies in the destination texture.
  const int32_t out_z_per_batch = api::utils::div_up(out_channels, 4);

  // We copy the entire width and height dimension. For the channel dimension,
  // we use the z-dimension of the global_size to specify the texture range.
  // The shader combines the global invocation id and the dst_offset to get
  // the actual coordinate. None of this depends on the batch index, so it is
  // computed once.
  const uvec3 global_size{
      dim_at<Dim4D::Width>(in_sizes),
      dim_at<Dim4D::Height>(in_sizes),
      api::utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
  const uvec3 local_size = adaptive_work_group_size(global_size);

  // Mirrors the CopyArgs uniform block of the copy_channel_offset shader; the
  // `unused*` members pad each ivec3 / scalar group out to 16 bytes.
  struct Block final {
    api::utils::ivec4 out_sizes;
    api::utils::ivec4 in_sizes;
    int32_t channel_range;
    int32_t src_channel_offset;
    int32_t dst_channel_offset;
    int32_t unused;
    ivec3 range;
    int32_t unused1;
    ivec3 dst_offset;
    int32_t unused2;
  };

  const auto batch_count = dim_at<Dim4D::Batch>(in_sizes);

  // Copy one batch at a time.
  for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
    // Only the z component changes between batches: skip the z-slices of the
    // preceding batches, then the slices before the copied range.
    const ivec3 dst_offset{0, 0, dst_first_z + batch_idx * out_z_per_batch};

    const Block channel_offset_params{
        api::utils::make_whcn_ivec4(out_sizes),
        api::utils::make_whcn_ivec4(in_sizes),
        channel_range,
        src_channel_offset,
        dst_channel_offset,
        0,
        api::utils::make_ivec3(global_size),
        0,
        dst_offset,
        0,
    };

    graph.execute_nodes().emplace_back(new ExecuteNode(
        graph,
        VK_KERNEL_FROM_STR(kernel_name),
        global_size,
        local_size,
        // Inputs and Outputs
        {
            {out, api::MemoryAccessType::WRITE},
            {out, api::MemoryAccessType::READ},
            {in, api::MemoryAccessType::READ},
        },
        // Parameter buffers
        {graph.create_params_buffer(channel_offset_params)},
        // Specialization Constants
        {}));
  }
}
179+
180+
// Overload that takes the range and offsets as graph values. Each of
// `range_ref`, `src_offset_ref` and `dst_offset_ref` is read from the graph as
// an int list, converted to an ivec3, and forwarded to the ivec3 overload of
// add_copy_offset_node.
void add_copy_offset_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef range_ref,
    ValueRef src_offset_ref,
    ValueRef dst_offset_ref,
    ValueRef out) {
  // Reads the int list stored at `list_ref` and converts it to an ivec3.
  const auto as_ivec3 = [&graph](ValueRef list_ref) -> ivec3 {
    return api::utils::make_ivec3(*graph.get_int_list(list_ref));
  };

  const ivec3 copy_extent = as_ivec3(range_ref);
  const ivec3 read_from = as_ivec3(src_offset_ref);
  const ivec3 write_to = as_ivec3(dst_offset_ref);

  add_copy_offset_node(graph, in, copy_extent, read_from, write_to, out);
}
195+
196+
// Operator entry point for `etvk.copy_offset`. Unpacks the argument list as
// (in, range, src_offset, dst_offset, out) and forwards it to the ValueRef
// overload of add_copy_offset_node. `args` must hold at least five entries.
void copy_offset(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args[0];
  const ValueRef range_ref = args[1];
  const ValueRef src_offset_ref = args[2];
  const ValueRef dst_offset_ref = args[3];
  const ValueRef out = args[4];

  add_copy_offset_node(
      graph, in, range_ref, src_offset_ref, dst_offset_ref, out);
}
199+
200+
// Operator entry point for `etvk.copy_channel_offset`. Unpacks the argument
// list as (in, channel_range, src_channel_offset, dst_channel_offset, out),
// extracts the three scalars from the graph, and forwards everything to
// add_copy_channel_offset_node. `args` must hold at least five entries.
void copy_channel_offset(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  const ValueRef in = args[0];
  const ValueRef out = args[4];

  const auto range_val = graph.extract_scalar<int64_t>(args[1]);
  const auto src_offset_val = graph.extract_scalar<int64_t>(args[2]);
  const auto dst_offset_val = graph.extract_scalar<int64_t>(args[3]);

  // NOTE(review): add_copy_channel_offset_node takes int32_t, so the extracted
  // 64-bit scalars are narrowed here without a range check.
  add_copy_channel_offset_node(
      graph,
      in,
      static_cast<int32_t>(range_val),
      static_cast<int32_t>(src_offset_val),
      static_cast<int32_t>(dst_offset_val),
      out);
}
218+
219+
// Registers the two copy operators under the `etvk.` names, binding each name
// to the entry-point function defined above. Per the commit summary, the
// `etvk.*` operators exist to facilitate testing.
REGISTER_OPERATORS {
220+
VK_REGISTER_OP(etvk.copy_offset, copy_offset);
221+
VK_REGISTER_OP(etvk.copy_channel_offset, copy_channel_offset);
222+
}
223+
70224
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Copy.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414

1515
namespace vkcompute {
1616

17+
// add_copy_offset_node resembles the vkCmdCopyImage command. It copies the
18+
// texture extents specified by the range, src_offset, and dst_offset (all are
19+
// in texture coordinates (x, y, z)) from the input image to the output image.
20+
//
21+
// It is possible to have input and output to point to the same image
22+
// object. But when the source range and destination range overlap, the behavior
23+
// is undefined.
1724
void add_copy_offset_node(
1825
ComputeGraph& graph,
1926
const ValueRef in,
@@ -22,4 +29,25 @@ void add_copy_offset_node(
2229
const api::utils::ivec3& dst_offset,
2330
const ValueRef out);
2431

32+
// add_copy_channel_offset_node behaves similarly to add_copy_offset_node, except that it
33+
// works on the channel dimensions of the tensor (up to 4 dimensions in NCHW).
34+
// The range and offset arguments are in the tensor coordinate. It assumes the
35+
// underlying texture is channel-packed.
36+
//
37+
// This function is a specialized implementation for copying
38+
// channel packed values. The complication comes from when reading / writing the
39+
// channel dimension on indices that are not aligned to packing, we will need to
40+
// be careful about the boundaries.
41+
//
42+
// It achieves the following:
43+
// out[:, dst_channel_offset:dst_channel_offset + channel_range, :, :] =
44+
// in [:, src_channel_offset:src_channel_offset + channel_range, :, :]
45+
void add_copy_channel_offset_node(
46+
ComputeGraph& graph,
47+
const ValueRef in,
48+
int32_t channel_range,
49+
int32_t src_channel_offset,
50+
int32_t dst_channel_offset,
51+
const ValueRef out);
52+
2553
} // namespace vkcompute

0 commit comments

Comments
 (0)