Skip to content

Commit af0a246

Browse files
authored
[ET-VK] Migrate ops to use DynamicDispatchNode (#11353)
## Changes * Migrate operators that are used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode` ## Motivation `DynamicDispatchNode` is a subclass of `DispatchNode` that allows dynamic selection of compute shaders, global and local work group sizing whenever the command buffer is encoded. This is critical for ensuring optimum performance when input shapes are dynamic, since it allows operators to select the best compute shader for the input conditions and also to adjust global work group sizing to launch the minimum number of work groups necessary. Without this change, performance of llama 3.2 1B with dynamic shapes enabled is terrible (< 1 tok/s) because global work group sizing is determined based on maximum tensor sizes, which is based on the maximum sequence length. In practice, the sequence length dimension of tensors (even during the prefill phase) will not approach the maximum. This results in a lot of inactive threads launched during compute shader dispatches. Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/)
1 parent f50bfa9 commit af0a246

20 files changed

+521
-264
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,15 @@ ValueRef ComputeGraph::add_symint(const int32_t val) {
449449
return idx;
450450
}
451451

452+
ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
453+
for (int i = 0; i < values_.size(); ++i) {
454+
if (values_.at(i).isInt() && values_.at(i).toInt() == val) {
455+
return i;
456+
}
457+
}
458+
return add_scalar(val);
459+
}
460+
452461
ValueRef ComputeGraph::set_input_tensor(
453462
const ValueRef idx,
454463
const bool use_staging) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,13 @@ class ComputeGraph final {
604604

605605
ValueRef add_symint(const int32_t val);
606606

607+
/*
608+
* Searches the graph's value list for a Int value with the specified value.
609+
* If one is found, returns the index of the value. Otherwise, add a new value
610+
* and return the index of the new value.
611+
*/
612+
ValueRef get_or_add_value_for_int(const int64_t val);
613+
607614
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
608615
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
609616

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
2525
const ResizeFunction& resize_fn)
2626
: DispatchNode(
2727
graph,
28-
vkapi::ShaderInfo(),
29-
{1u, 1u, 1u},
28+
pick_shader_fn(&graph, args, resize_args),
3029
{1u, 1u, 1u},
30+
{8u, 8u, 1u},
3131
args,
3232
params,
3333
push_constants,
@@ -37,7 +37,6 @@ DynamicDispatchNode::DynamicDispatchNode(
3737
pick_shader_fn_(pick_shader_fn),
3838
pick_global_wg_fn_(pick_global_wg_fn),
3939
pick_local_wg_fn_(pick_local_wg_fn) {
40-
shader_ = pick_shader_fn(&graph, args, resize_args);
4140
global_workgroup_size_ =
4241
pick_global_wg_fn(&graph, shader_, args, resize_args);
4342
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
@@ -30,8 +31,8 @@ void check_binary_op_args(
3031
void resize_binary_op_node(
3132
ComputeGraph* graph,
3233
const std::vector<ArgGroup>& args,
33-
const std::vector<ValueRef>& extra_args) {
34-
(void)extra_args;
34+
const std::vector<ValueRef>& resize_args) {
35+
(void)resize_args;
3536
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
3637

3738
// TODO(T183442143): Verify tensors are broadcastable.
@@ -78,11 +79,11 @@ void add_binary_op_texture_node(
7879
add_storage_type_suffix(kernel_name, *t_out);
7980
add_dtype_suffix(kernel_name, *t_out);
8081

81-
graph.execute_nodes().emplace_back(new DispatchNode(
82+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
8283
graph,
8384
VK_KERNEL_FROM_STR(kernel_name),
84-
graph.create_global_wg_size(out),
85-
graph.create_local_wg_size(out),
85+
default_pick_global_wg_size,
86+
default_pick_local_wg_size,
8687
// Inputs and Outputs
8788
{{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}},
8889
// Shader params buffers
@@ -122,11 +123,11 @@ void add_binary_op_buffer_node(
122123
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
123124
add_dtype_suffix(kernel_name, graph.dtype_of(out));
124125

125-
graph.execute_nodes().emplace_back(new DispatchNode(
126+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
126127
graph,
127128
VK_KERNEL_FROM_STR(kernel_name),
128-
graph.create_global_wg_size(out),
129-
graph.create_local_wg_size(out),
129+
default_pick_global_wg_size,
130+
default_pick_local_wg_size,
130131
// Inputs and Outputs
131132
{{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}},
132133
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
1415

1516
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
@@ -21,8 +22,8 @@ namespace vkcompute {
2122
void resize_clone_node(
2223
ComputeGraph* graph,
2324
const std::vector<ArgGroup>& args,
24-
const std::vector<ValueRef>& extra_args) {
25-
(void)extra_args;
25+
const std::vector<ValueRef>& resize_args) {
26+
(void)resize_args;
2627
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
2728
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
2829
// TODO: support for when dimensionality doesn't match, i.e. clone is used to
@@ -41,11 +42,11 @@ void add_clone_node(
4142
std::string kernel_name = "clone";
4243
add_dtype_suffix(kernel_name, *t_out);
4344

44-
graph.execute_nodes().emplace_back(new DispatchNode(
45+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
4546
graph,
4647
VK_KERNEL_FROM_STR(kernel_name),
47-
graph.create_global_wg_size(out),
48-
graph.create_local_wg_size(out),
48+
default_pick_global_wg_size,
49+
default_pick_local_wg_size,
4950
// Inputs and Outputs
5051
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
5152
// Parameter Buffers
@@ -60,6 +61,17 @@ void add_clone_node(
6061
resize_clone_node));
6162
}
6263

64+
utils::uvec3 clone_image_to_buffer_global_wg_size(
65+
ComputeGraph* graph,
66+
const vkapi::ShaderInfo& shader,
67+
const std::vector<ArgGroup>& args,
68+
const std::vector<ValueRef>& resize_args) {
69+
(void)shader;
70+
(void)resize_args;
71+
const ValueRef image = args.at(1).refs.at(0);
72+
return graph->create_global_wg_size(image);
73+
}
74+
6375
void add_image_to_buffer_node(
6476
ComputeGraph& graph,
6577
const ValueRef image,
@@ -68,12 +80,11 @@ void add_image_to_buffer_node(
6880
add_dtype_suffix(kernel_name, graph.dtype_of(image));
6981
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
7082

71-
utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
72-
graph.execute_nodes().emplace_back(new DispatchNode(
83+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
7384
graph,
7485
shader,
75-
global_wg_size,
76-
graph.create_local_wg_size(global_wg_size),
86+
clone_image_to_buffer_global_wg_size,
87+
default_pick_local_wg_size,
7788
// Input and Outputs
7889
{{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
7990
// Parameter Buffers
@@ -96,12 +107,11 @@ void add_buffer_to_image_node(
96107
add_dtype_suffix(kernel_name, graph.dtype_of(image));
97108
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
98109

99-
utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
100-
graph.execute_nodes().emplace_back(new DispatchNode(
110+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
101111
graph,
102112
shader,
103-
global_wg_size,
104-
graph.create_local_wg_size(global_wg_size),
113+
default_pick_global_wg_size,
114+
default_pick_local_wg_size,
105115
// Input and Outputs
106116
{{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
107117
// Parameter Buffers

backends/vulkan/runtime/graph/ops/impl/Common.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ utils::uvec3 default_pick_global_wg_size(
1414
ComputeGraph* graph,
1515
const vkapi::ShaderInfo& shader,
1616
const std::vector<ArgGroup>& args,
17-
const std::vector<ValueRef>& additional_args) {
17+
const std::vector<ValueRef>& resize_args) {
1818
(void)shader;
19+
(void)resize_args;
1920
const ValueRef out = args.at(0).refs.at(0);
2021
return graph->create_global_wg_size(out);
2122
}
@@ -25,8 +26,10 @@ utils::uvec3 default_pick_local_wg_size(
2526
const vkapi::ShaderInfo& shader,
2627
const utils::uvec3& global_workgroup_size,
2728
const std::vector<ArgGroup>& args,
28-
const std::vector<ValueRef>& additional_args) {
29+
const std::vector<ValueRef>& resize_args) {
2930
(void)shader;
31+
(void)args;
32+
(void)resize_args;
3033
return graph->create_local_wg_size(global_workgroup_size);
3134
}
3235

backends/vulkan/runtime/graph/ops/impl/Common.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,23 @@ namespace vkcompute {
1717
* Creates a global workgroup size based on the first output tensor in the args.
1818
* This is a utility function that extracts the output tensor from
1919
* args.at(0).refs.at(0) and calls graph->create_global_wg_size(out) on it.
20-
*
21-
* @param graph The ComputeGraph instance
22-
* @param args Vector of ArgGroup containing the output tensor reference
23-
* @return utils::uvec3 The global workgroup size
2420
*/
2521
utils::uvec3 default_pick_global_wg_size(
2622
ComputeGraph* graph,
2723
const vkapi::ShaderInfo& shader,
2824
const std::vector<ArgGroup>& args,
29-
const std::vector<ValueRef>& additional_args);
25+
const std::vector<ValueRef>& resize_args);
3026

3127
/**
3228
* Creates a local workgroup size based on the first output tensor in the args.
3329
* This is a utility function that extracts the output tensor from
3430
* args.at(0).refs.at(0) and calls graph->create_local_wg_size(out) on it.
35-
*
36-
* @param graph The ComputeGraph instance
37-
* @param args Vector of ArgGroup containing the output tensor reference
38-
* @return utils::uvec3 The local workgroup size
3931
*/
4032
utils::uvec3 default_pick_local_wg_size(
4133
ComputeGraph* graph,
4234
const vkapi::ShaderInfo& shader,
4335
const utils::uvec3& global_workgroup_size,
4436
const std::vector<ArgGroup>& args,
45-
const std::vector<ValueRef>& additional_args);
37+
const std::vector<ValueRef>& resize_args);
4638

4739
} // namespace vkcompute

0 commit comments

Comments
 (0)