Skip to content

Commit 118afa0

Browse files
authored
[ET-VK][ez][Refactor] Re-order DispatchNode arguments to match shader layout spec (#10700)
## Context As title. Note that this PR was written mainly with a meta internal coding agent. This PR re-orders the arguments to `DispatchNode` to be more intuitive. I asked the agent: --- I want to change the constructor of the class ``` explicit DispatchNode( ComputeGraph& graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const utils::uvec3& local_workgroup_size, const std::vector<ArgGroup>& args, const vkapi::ParamsBindList& params, const vkapi::SpecVarList& spec_vars = {}, const ResizeFunction& resize_fn = nullptr, const std::vector<ValueRef>& resize_args = {}, const std::vector<PushConstantDataInfo>& push_constants = {}); ``` to instead be ``` explicit DispatchNode( ComputeGraph& graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const utils::uvec3& local_workgroup_size, const std::vector<ArgGroup>& args, const vkapi::ParamsBindList& params,, const std::vector<PushConstantDataInfo>& push_constants = {}, const vkapi::SpecVarList& spec_vars = {}, const std::vector<ValueRef>& resize_args = {}, const ResizeFunction& resize_fn = nullptr); ``` Can you make this change and update the callsites as well? --- The motivation is to have the arguments match the order in which parameter UBOs, push constant blocks, and specialization variables are declared in a GLSL shader. The order of `resize`_args` and `resize_fn` was also swapped in the interest of having the function pointer be the last argument. It will also make more sense in a following diff where a `DynamicDispatchNode` class is introduced, which will allow selecting a different compute shader depending on input sizes. As a small additional change, I also asked the agent to --- Go through all the files under `xplat/executorch/backends/vulkan/runtime/graph/ops/impl` Change `vkapi::MemoryAccessType::WRITE` to `vkapi::kWrite` and `vkapi::MemoryAccessType::READ` to `vkapi::kRead` for each file. --- Differential Revision: [D74203482](https://our.internmc.facebook.com/intern/diff/D74203482/)
1 parent 4945f4e commit 118afa0

34 files changed

+316
-167
lines changed

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ DispatchNode::DispatchNode(
2121
const utils::uvec3& local_workgroup_size,
2222
const std::vector<ArgGroup>& args,
2323
const vkapi::ParamsBindList& params,
24+
const std::vector<PushConstantDataInfo>& push_constants,
2425
const vkapi::SpecVarList& spec_vars,
25-
const ResizeFunction& resize_fn,
2626
const std::vector<ValueRef>& resize_args,
27-
const std::vector<PushConstantDataInfo>& push_constants)
27+
const ResizeFunction& resize_fn)
2828
: ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
2929
shader_(shader),
3030
global_workgroup_size_(global_workgroup_size),

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ class DispatchNode final : public ExecuteNode {
3333
const utils::uvec3& local_workgroup_size,
3434
const std::vector<ArgGroup>& args,
3535
const vkapi::ParamsBindList& params,
36+
const std::vector<PushConstantDataInfo>& push_constants = {},
3637
const vkapi::SpecVarList& spec_vars = {},
37-
const ResizeFunction& resize_fn = nullptr,
3838
const std::vector<ValueRef>& resize_args = {},
39-
const std::vector<PushConstantDataInfo>& push_constants = {});
39+
const ResizeFunction& resize_fn = nullptr);
4040

4141
~DispatchNode() override = default;
4242

backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,19 @@ void add_arange_node(
9494
graph.create_global_wg_size(out),
9595
graph.create_local_wg_size(out),
9696
// Inputs and Outputs
97-
{{out, vkapi::MemoryAccessType::WRITE}},
97+
{{out, vkapi::kWrite}},
9898
// Shader params buffers
9999
{t_out->sizes_ubo(),
100100
graph.create_params_buffer(start_val),
101101
graph.create_params_buffer(step_val)},
102+
// Push Constants
103+
{},
102104
// Specialization Constants
103105
{},
106+
// Resize Args
107+
{start, end, step},
104108
// Resizing Logic
105-
resize_arange_node,
106-
{start, end, step}));
109+
resize_arange_node));
107110
}
108111

109112
void arange(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,19 @@ void add_native_batch_norm_node(
9090
VK_KERNEL_FROM_STR(kernel_name),
9191
graph.create_global_wg_size(out_ref),
9292
graph.create_local_wg_size(out_ref),
93-
{{out_ref, vkapi::MemoryAccessType::WRITE},
94-
{{in_ref, arg_weight, arg_bias, arg_mean, arg_var},
95-
vkapi::MemoryAccessType::READ}},
93+
{{out_ref, vkapi::kWrite},
94+
{{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}},
9695
{t_out->logical_limits_ubo(),
9796
graph.create_params_buffer(epsilon),
98-
graph.create_params_buffer(num_texel_per_batch)}));
97+
graph.create_params_buffer(num_texel_per_batch)},
98+
// Push Constants
99+
{},
100+
// Specialization Constants
101+
{},
102+
// Resize Args
103+
{},
104+
// Resizing Logic
105+
nullptr));
99106
}
100107

101108
void native_batch_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,20 @@ void add_binary_op_texture_node(
8484
graph.create_global_wg_size(out),
8585
graph.create_local_wg_size(out),
8686
// Inputs and Outputs
87-
{{out, vkapi::MemoryAccessType::WRITE},
88-
{{arg1, arg2}, vkapi::MemoryAccessType::READ}},
87+
{{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}},
8988
// Shader params buffers
9089
{},
91-
// Specialization Constants
92-
{t_out->hashed_layout(), t_in1->hashed_layout(), t_in2->hashed_layout()},
93-
// Resizing Logic
94-
resize_binary_op_node,
95-
{},
90+
// Push Constants
9691
{{graph.sizes_pc_of(out),
9792
graph.sizes_pc_of(arg1),
9893
graph.sizes_pc_of(arg2),
99-
PushConstantDataInfo(&binary_ops_params, sizeof(binary_ops_params))}}));
94+
PushConstantDataInfo(&binary_ops_params, sizeof(binary_ops_params))}},
95+
// Specialization Constants
96+
{t_out->hashed_layout(), t_in1->hashed_layout(), t_in2->hashed_layout()},
97+
// Resize Args
98+
{},
99+
// Resizing Logic
100+
resize_binary_op_node));
100101
}
101102

102103
void add_binary_op_buffer_node(
@@ -127,17 +128,10 @@ void add_binary_op_buffer_node(
127128
graph.create_global_wg_size(out),
128129
graph.create_local_wg_size(out),
129130
// Inputs and Outputs
130-
{{out, vkapi::MemoryAccessType::WRITE},
131-
{{in1, in2}, vkapi::MemoryAccessType::READ}},
131+
{{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}},
132132
// Shader params buffers
133133
{},
134-
// Specialization Constants
135-
{graph.packed_dim_of(out),
136-
graph.packed_dim_of(in1),
137-
graph.packed_dim_of(in2)},
138-
// Resizing Logic
139-
resize_binary_op_node,
140-
{},
134+
// Push Constants
141135
{{
142136
graph.sizes_pc_of(in1),
143137
graph.sizes_pc_of(in2),
@@ -146,7 +140,15 @@ void add_binary_op_buffer_node(
146140
graph.strides_pc_of(in2),
147141
graph.numel_pc_of(out),
148142
PushConstantDataInfo(&alpha_val, sizeof(float)),
149-
}}));
143+
}},
144+
// Specialization Constants
145+
{graph.packed_dim_of(out),
146+
graph.packed_dim_of(in1),
147+
graph.packed_dim_of(in2)},
148+
// Resize Args
149+
{},
150+
// Resizing Logic
151+
resize_binary_op_node));
150152
}
151153

152154
void add_binary_op_node(

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ void add_clone_node(
5050
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
5151
// Parameter Buffers
5252
{t_out->logical_limits_ubo()},
53+
// Push Constants
54+
{},
5355
// Specialization Constants
5456
{},
57+
// Resize Args
58+
{},
5559
// Resizing Logic
5660
resize_clone_node));
5761
}
@@ -74,8 +78,12 @@ void add_image_to_buffer_node(
7478
{{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
7579
// Parameter Buffers
7680
{graph.sizes_ubo(image), graph.strides_ubo(buffer)},
81+
// Push Constants
82+
{},
7783
// Specialization Constants
7884
{graph.hashed_layout_of(image)},
85+
// Resize Args
86+
{},
7987
// Resizing Logic
8088
resize_clone_node));
8189
}
@@ -98,8 +106,12 @@ void add_buffer_to_image_node(
98106
{{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
99107
// Parameter Buffers
100108
{graph.sizes_ubo(image), graph.strides_ubo(buffer)},
109+
// Push Constants
110+
{},
101111
// Specialization Constants
102112
{graph.hashed_layout_of(image)},
113+
// Resize Args
114+
{},
103115
// Resizing Logic
104116
resize_clone_node));
105117
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -444,16 +444,17 @@ void add_conv2d_node(
444444
wg_size,
445445
graph.create_local_wg_size(wg_size),
446446
// Inputs and Outputs
447-
{{out, vkapi::MemoryAccessType::WRITE},
448-
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
447+
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
449448
// Shader params buffers
450449
param_buffers,
450+
// Push Constants
451+
push_constants,
451452
// Specialization Constants
452453
{},
453-
// Resizing Logic
454-
resize_conv2d_node,
454+
// Resize Args
455455
{weight_data, stride, padding, dilation, transposed, output_padding},
456-
push_constants));
456+
// Resizing Logic
457+
resize_conv2d_node));
457458
}
458459

459460
void add_conv1d_node(
@@ -548,23 +549,25 @@ void add_conv1d_node(
548549
global_size,
549550
local_size,
550551
// Inputs and Outputs
551-
{{out, vkapi::MemoryAccessType::WRITE},
552-
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
552+
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
553553
// Shader params buffers
554554
{
555555
t_out->logical_limits_ubo(),
556556
t_in->sizes_ubo(),
557557
graph.create_params_buffer(kernel_params),
558558
graph.create_params_buffer(out_params),
559559
},
560+
// Push Constants
561+
{},
560562
// Specialization Constants
561563
{t_out->hashed_layout(),
562564
t_in->hashed_layout(),
563565
t_weight->hashed_layout(),
564566
t_bias->hashed_layout()},
567+
// Resize Args
568+
{weight, stride, padding, dilation},
565569
// Resizing Logic
566-
resize_conv1d_node,
567-
{weight, stride, padding, dilation}));
570+
resize_conv1d_node));
568571
}
569572

570573
void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,22 @@ void add_copy_offset_node(
5050
},
5151
// Parameter buffers
5252
{},
53+
// Push Constants
54+
{
55+
PushConstantDataInfo(&range, sizeof(range), sizeof(ivec4)),
56+
PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)),
57+
PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)),
58+
},
5359
// Specialization Constants
5460
{graph.hashed_layout_of(out),
5561
graph.hashed_layout_of(in),
5662
(calc_out_pos_using_src_chnl ? 1
5763
: calc_in_pos_using_dst_chnl ? 2
5864
: 0)},
59-
nullptr,
65+
// Resize Args
6066
{},
61-
{
62-
PushConstantDataInfo(&range, sizeof(range), sizeof(ivec4)),
63-
PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)),
64-
PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)),
65-
}));
67+
// Resizing Logic
68+
nullptr));
6669
}
6770

6871
void add_copy_packed_dim_offset_node(
@@ -138,22 +141,25 @@ void add_copy_packed_dim_offset_node(
138141
graph.create_local_wg_size(global_wg_size),
139142
// Inputs and Outputs
140143
{
141-
{out, vkapi::MemoryAccessType::WRITE},
142-
{out, vkapi::MemoryAccessType::READ},
143-
{in, vkapi::MemoryAccessType::READ},
144+
{out, vkapi::kWrite},
145+
{out, vkapi::kRead},
146+
{in, vkapi::kRead},
144147
},
145148
// Parameter buffers
146149
{},
147-
// Specialization Constants
148-
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
149-
nullptr,
150-
{},
150+
// Push Constants
151151
{
152152
PushConstantDataInfo(
153153
&final_range, sizeof(final_range), sizeof(ivec4)),
154154
PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)),
155155
PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)),
156-
}));
156+
},
157+
// Specialization Constants
158+
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
159+
// Resize Args
160+
{},
161+
// Resizing Logic
162+
nullptr));
157163
}
158164

159165
void add_copy_channel_offset_node(
@@ -248,22 +254,24 @@ void add_copy_channel_offset_node(
248254
local_size,
249255
// Inputs and Outputs
250256
{
251-
{out, vkapi::MemoryAccessType::WRITE},
252-
{out, vkapi::MemoryAccessType::READ},
253-
{in, vkapi::MemoryAccessType::READ},
257+
{out, vkapi::kWrite},
258+
{out, vkapi::kRead},
259+
{in, vkapi::kRead},
254260
},
255261
// Parameter buffers
256262
{},
257-
// Specialization Constants
258-
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
259-
nullptr,
260-
{},
263+
// Push Constants
261264
{graph.sizes_pc_of(out),
262265
graph.sizes_pc_of(in),
263266
PushConstantDataInfo(&range_params, sizeof(range_params)),
264267
PushConstantDataInfo(&offset_params, sizeof(offset_params)),
265-
PushConstantDataInfo(
266-
&src_channel_offset, sizeof(src_channel_offset))}));
268+
PushConstantDataInfo(&src_channel_offset, sizeof(src_channel_offset))},
269+
// Specialization Constants
270+
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
271+
// Resize Args
272+
{},
273+
// Resizing Logic
274+
nullptr));
267275
}
268276
}
269277

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,16 @@ void add_embedding_node(
5858
{
5959
t_out->sizes_ubo(),
6060
},
61+
// Push Constants
62+
{},
63+
// Specialization Constants
6164
{t_out->hashed_layout(),
6265
t_in->hashed_layout(),
63-
t_weight->hashed_layout()}));
66+
t_weight->hashed_layout()},
67+
// Resize Args
68+
{},
69+
// Resizing Logic
70+
nullptr));
6471
}
6572

6673
void embedding(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ void add_flip_node(
7474
graph.sizes_ubo(out),
7575
graph.create_params_buffer(dim_bitmap),
7676
},
77+
// Push Constants
78+
{},
7779
// Specialization Constants
7880
{},
81+
// Resize Args
82+
{},
7983
// Resizing Logic
8084
resize_flip_node));
8185
}

backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,17 @@ void add_full_node(
5050
graph.create_global_wg_size(out),
5151
graph.create_local_wg_size(out),
5252
// Inputs and Outputs
53-
{{out, vkapi::MemoryAccessType::WRITE}},
53+
{{out, vkapi::kWrite}},
5454
// Shader params buffers
5555
{t_out->sizes_ubo(), graph.create_params_buffer(fill_value_val)},
56+
// Push Constants
57+
{},
5658
// Specialization Constants
5759
{SV(t_out->packed_dim())},
60+
// Resize Args
61+
{size_or_in},
5862
// Resizing Logic
59-
resize_full_node,
60-
{size_or_in}));
63+
resize_full_node));
6164
}
6265

6366
void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {

0 commit comments

Comments
 (0)