Skip to content

Commit 819cc03

Browse files
authored
[ET-VK] Replace Uniform buffers with push constants for copy op
Differential Revision: D66890851 Pull Request resolved: #7267
1 parent cb60bc7 commit 819cc03

File tree

3 files changed

+49
-50
lines changed

3 files changed

+49
-50
lines changed

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,16 @@ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
1818
${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)}
1919
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2020

21-
${layout_declare_ubo(B, "ivec4", "out_sizes")}
22-
${layout_declare_ubo(B, "ivec4", "in_sizes")}
23-
24-
layout(set = 0, binding = 5) uniform PRECISION restrict CopyArgs {
21+
layout(push_constant) uniform restrict Block {
22+
ivec4 out_sizes;
23+
ivec4 in_sizes;
2524
// Operates on (x, y, z) logical extents.
26-
ivec3 range;
25+
// channel_range is stored in range.w
26+
ivec4 range;
2727
// Analogus to range variable in copy. It defines the # of channel being
2828
// copied.
29-
int channel_range;
30-
ivec3 dst_offset;
31-
int dst_channel_offset;
29+
// dst channel offset is stored in dst_offset.w
30+
ivec4 dst_offset;
3231
int src_channel_offset;
3332
};
3433

@@ -47,11 +46,11 @@ void main() {
4746
// Note: Unlike other shaders, the range is often not equal to the destination
4847
// texture extent.
4948
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
50-
if (any(greaterThanEqual(lpos, range))) {
49+
if (any(greaterThanEqual(lpos, range.xyz))) {
5150
return;
5251
}
5352

54-
const ivec3 out_lpos = lpos + dst_offset;
53+
const ivec3 out_lpos = lpos + dst_offset.xyz;
5554

5655
const ivec4 out_tidx = lpos_to_tidx(out_lpos, out_sizes, out_axis_map.w, packed_dim);
5756

@@ -61,12 +60,12 @@ void main() {
6160
ivec4 in_tidx = out_tidx;
6261
for (int i=0; i<4; i++) {
6362

64-
in_tidx[packed_dim] = out_tidx[packed_dim] - dst_channel_offset + i;
63+
in_tidx[packed_dim] = out_tidx[packed_dim] - dst_offset.w + i;
6564

6665
// Handle the partial update for begining of channel in an existing tensor.
6766
// If the source channel index is below zero or exceeds the range, we skip
6867
// updating the element to avoid overwriting existing data.
69-
if ((in_tidx[packed_dim] < 0) || (in_tidx[packed_dim] >= channel_range)) {
68+
if ((in_tidx[packed_dim] < 0) || (in_tidx[packed_dim] >= range.w)) {
7069
continue;
7170
}
7271

backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ layout(std430) buffer;
1717
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
1818
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
1919

20-
${layout_declare_ubo(B, "ivec3", "range", "ivec3", "src_offset", "ivec3", "dst_offset")}
20+
layout(push_constant) uniform restrict Block {
21+
ivec3 range;
22+
ivec3 src_offset;
23+
ivec3 dst_offset;
24+
};
2125

2226
#include "indexing_utils.h"
2327

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@ void add_copy_offset_node(
3333
add_dtype_suffix(kernel_name, *t_out);
3434
add_storage_type_suffix(kernel_name, *t_out);
3535

36-
const struct Block final {
37-
alignas(16) ivec3 range;
38-
alignas(16) ivec3 src_offset;
39-
alignas(16) ivec3 dst_offset;
40-
} offset_params{
41-
range,
42-
src_offset,
43-
dst_offset,
44-
};
45-
4636
auto shader = VK_KERNEL_FROM_STR(kernel_name);
4737

4838
graph.execute_nodes().emplace_back(new DispatchNode(
@@ -56,11 +46,18 @@ void add_copy_offset_node(
5646
{in, vkapi::kRead},
5747
},
5848
// Parameter buffers
59-
{
60-
graph.create_params_buffer(offset_params),
61-
},
49+
{},
6250
// Specialization Constants
63-
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)}));
51+
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
52+
nullptr,
53+
{},
54+
{
55+
PushConstantDataInfo(&range, sizeof(range), sizeof(utils::ivec4)),
56+
PushConstantDataInfo(
57+
&src_offset, sizeof(src_offset), sizeof(utils::ivec4)),
58+
PushConstantDataInfo(
59+
&dst_offset, sizeof(dst_offset), sizeof(utils::ivec4)),
60+
}));
6461
}
6562

6663
void add_copy_channel_offset_node(
@@ -128,28 +125,23 @@ void add_copy_channel_offset_node(
128125
// The shader combines the global invocation id and the dst_offset to get
129126
// the actual coordinate.
130127

131-
ivec3 dst_offset{
128+
const ivec3 dst_offset{
132129
0, 0, dst_first_z + batch_idx * utils::div_up_4(out_channels)};
133130

134-
uvec3 global_size{
131+
const uvec3 global_size{
135132
utils::safe_downcast<uint32_t>(dim_at<kWidth4D>(in_sizes)),
136133
utils::safe_downcast<uint32_t>(dim_at<kHeight4D>(in_sizes)),
137134
utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
138-
uvec3 local_size = graph.create_local_wg_size(global_size);
139-
140-
const struct Block final {
141-
ivec3 range;
142-
int32_t channel_range;
143-
ivec3 dst_offset;
144-
int32_t dst_channel_offset;
145-
int32_t src_channel_offset;
146-
} channel_offset_params{
147-
utils::make_ivec3(global_size),
148-
channel_range,
149-
dst_offset,
150-
dst_channel_offset,
151-
src_channel_offset,
152-
};
135+
const uvec3 local_size = graph.create_local_wg_size(global_size);
136+
137+
const utils::ivec4 range_params = {
138+
static_cast<int>(global_size[0]),
139+
static_cast<int>(global_size[1]),
140+
static_cast<int>(global_size[2]),
141+
channel_range};
142+
143+
const utils::ivec4 offset_params = {
144+
dst_offset[0], dst_offset[1], dst_offset[2], dst_channel_offset};
153145

154146
auto shader = VK_KERNEL_FROM_STR(kernel_name);
155147

@@ -165,13 +157,17 @@ void add_copy_channel_offset_node(
165157
{in, vkapi::MemoryAccessType::READ},
166158
},
167159
// Parameter buffers
168-
{
169-
t_out->sizes_ubo(),
170-
t_in->sizes_ubo(),
171-
graph.create_params_buffer(channel_offset_params),
172-
},
160+
{},
173161
// Specialization Constants
174-
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)}));
162+
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
163+
nullptr,
164+
{},
165+
{graph.sizes_pc_of(out),
166+
graph.sizes_pc_of(in),
167+
PushConstantDataInfo(&range_params, sizeof(range_params)),
168+
PushConstantDataInfo(&offset_params, sizeof(offset_params)),
169+
PushConstantDataInfo(
170+
&src_channel_offset, sizeof(src_channel_offset))}));
175171
}
176172
}
177173

0 commit comments

Comments
 (0)