
Commit 40924c4

[ET-VK] Migrate ops to use DynamicDispatchNode
## Changes

* Migrate operators that are used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode`.

## Motivation

`DynamicDispatchNode` is a subclass of `DispatchNode` that allows dynamic selection of the compute shader and of the global and local work group sizes whenever the command buffer is encoded. This is critical for ensuring optimal performance when input shapes are dynamic, since it lets operators select the best compute shader for the current input conditions and adjust the global work group size so that only the minimum number of work groups is launched.

Without this change, performance of Llama 3.2 1B with dynamic shapes enabled is terrible (< 1 tok/s), because global work group sizes are derived from maximum tensor sizes, which in turn are based on the maximum sequence length. In practice, the sequence length dimension of the tensors (even during the prefill phase) does not approach that maximum, so a large number of inactive threads are launched on every compute shader dispatch.

Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/)

ghstack-source-id: 287884655

Pull Request resolved: #11312
1 parent 4348319 commit 40924c4
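The diffs below all follow the same pattern: instead of baking a fixed `utils::uvec3` computed at graph-build time via `graph.create_global_wg_size(out)` into the node, each op now passes callbacks (`default_pick_global_wg_size`, `default_pick_local_wg_size`, presumably provided by the newly included `ops/impl/Common.h`, or an op-specific picker) that are re-evaluated every time the command buffer is encoded. As a minimal sketch condensed from the MatMul changes in this commit (the name `my_op_global_wg_size` is hypothetical, not part of the diff), a custom global work group size picker looks like this:

// Sketch only: a global work group size resolver with the callback signature
// used by DynamicDispatchNode in this commit. It runs at command buffer encode
// time, so it sees the output tensor's current (virtually resized) sizes
// rather than the maximum sizes used when the graph was built.
utils::uvec3 my_op_global_wg_size( // hypothetical op-specific picker
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;
  const ValueRef out = args.at(0).refs.at(0);
  // One invocation per output element: W x H x (C * N), mirroring
  // matmul_naive_buffer_global_wg_size in the MatMul.cpp diff below.
  return {
      graph->size_at<uint32_t>(-1, out),
      graph->size_at<uint32_t>(-2, out),
      graph->size_at<uint32_t>(-3, out) * graph->size_at<uint32_t>(-4, out)};
}

// Registration then passes callbacks instead of fixed sizes, e.g.:
//   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
//       graph,
//       VK_KERNEL_FROM_STR(kernel_name),
//       my_op_global_wg_size,        // re-picked on every encode
//       default_pick_local_wg_size,
//       ...));

Ops that also need to swap shader variants at encode time (see MatMul below) additionally pass a shader picker in place of the fixed `VK_KERNEL_FROM_STR(kernel_name)` argument.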

File tree: 13 files changed (+456 / -230 lines)


backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 9 additions & 0 deletions

@@ -449,6 +449,15 @@ ValueRef ComputeGraph::add_symint(const int32_t val) {
   return idx;
 }
 
+ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
+  for (int i = 0; i < values_.size(); ++i) {
+    if (values_.at(i).isInt() && values_.at(i).toInt() == val) {
+      return i;
+    }
+  }
+  return add_scalar(val);
+}
+
 ValueRef ComputeGraph::set_input_tensor(
     const ValueRef idx,
     const bool use_staging) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 7 additions & 0 deletions

@@ -604,6 +604,13 @@ class ComputeGraph final {
 
   ValueRef add_symint(const int32_t val);
 
+  /*
+   * Searches the graph's value list for a Int value with the specified value.
+   * If one is found, returns the index of the value. Otherwise, add a new value
+   * and return the index of the new value.
+   */
+  ValueRef get_or_add_value_for_int(const int64_t val);
+
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
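For context, a hypothetical call site (not part of this commit) illustrates the documented search-or-add behaviour:

// Hypothetical usage sketch of the new helper: if the graph's value list
// already contains an Int equal to 4, its index is reused; otherwise a new
// scalar Value is appended via add_scalar().
const ValueRef four = graph.get_or_add_value_for_int(4);
// `four` can then be passed wherever a ValueRef to an Int is expected,
// e.g. as one of a node's resize args, without duplicating constants.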

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 9 additions & 8 deletions

@@ -8,6 +8,7 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>

@@ -30,8 +31,8 @@ void check_binary_op_args(
 void resize_binary_op_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)extra_args;
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
 
   // TODO(T183442143): Verify tensors are broadcastable.

@@ -78,11 +79,11 @@ void add_binary_op_texture_node(
   add_storage_type_suffix(kernel_name, *t_out);
   add_dtype_suffix(kernel_name, *t_out);
 
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      graph.create_global_wg_size(out),
-      graph.create_local_wg_size(out),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}},
       // Shader params buffers

@@ -122,11 +123,11 @@ void add_binary_op_buffer_node(
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
-     graph.create_global_wg_size(out),
-     graph.create_local_wg_size(out),
+     default_pick_global_wg_size,
+     default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}},
      // Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 12 additions & 13 deletions

@@ -10,6 +10,7 @@
 
 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
 
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>

@@ -21,8 +22,8 @@ namespace vkcompute {
 void resize_clone_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)extra_args;
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr in = graph->get_tensor(args[1].refs[0]);
   // TODO: support for when dimensionality doesn't match, i.e. clone is used to

@@ -41,11 +42,11 @@ void add_clone_node(
   std::string kernel_name = "clone";
   add_dtype_suffix(kernel_name, *t_out);
 
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
-     graph.create_global_wg_size(out),
-     graph.create_local_wg_size(out),
+     default_pick_global_wg_size,
+     default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Parameter Buffers

@@ -68,12 +69,11 @@ void add_image_to_buffer_node(
   add_dtype_suffix(kernel_name, graph.dtype_of(image));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
 
-  utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      shader,
-     global_wg_size,
-     graph.create_local_wg_size(global_wg_size),
+     default_pick_global_wg_size,
+     default_pick_local_wg_size,
      // Input and Outputs
      {{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
      // Parameter Buffers

@@ -96,12 +96,11 @@ void add_buffer_to_image_node(
   add_dtype_suffix(kernel_name, graph.dtype_of(image));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
 
-  utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      shader,
-     global_wg_size,
-     graph.create_local_wg_size(global_wg_size),
+     default_pick_global_wg_size,
+     default_pick_local_wg_size,
      // Input and Outputs
      {{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
      // Parameter Buffers

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 100 additions & 61 deletions

@@ -8,6 +8,7 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 

@@ -37,12 +38,12 @@ void check_matmul_args(
 void resize_matmul_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
+    const std::vector<ValueRef>& resize_args) {
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
   vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
 
-  bool mat2_is_transposed = graph->get_bool(extra_args[0]);
+  bool mat2_is_transposed = graph->get_bool(resize_args[0]);
 
   const int out_cols = utils::val_at(-2, mat1->sizes());
   const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())

@@ -56,6 +57,22 @@ void resize_matmul_node(
   out->virtual_resize(new_out_sizes);
 }
 
+/**
+ * Custom global workgroup size function for naive buffer matmul operations.
+ */
+utils::uvec3 matmul_naive_buffer_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  const ValueRef out = args.at(0).refs.at(0);
+  return {
+      graph->size_at<uint32_t>(-1, out),
+      graph->size_at<uint32_t>(-2, out),
+      graph->size_at<uint32_t>(-3, out) * graph->size_at<uint32_t>(-4, out)};
+}
+
 void add_matmul_naive_buffer_node(
     ComputeGraph& graph,
     const ValueRef mat1,

@@ -72,21 +89,16 @@ void add_matmul_naive_buffer_node(
   std::string kernel_name = "matmul_naive_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  utils::uvec3 global_size = {
-      graph.size_at<uint32_t>(-1, out),
-      graph.size_at<uint32_t>(-2, out),
-      graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
-
   int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
                                 graph.get_bool(mat2_is_transposed))
       ? 1
       : 0;
 
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
-     global_size,
-     graph.create_local_wg_size(global_size),
+     matmul_naive_buffer_global_wg_size,
+     default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
      // Shader params buffers

@@ -109,6 +121,22 @@ void add_matmul_naive_buffer_node(
      resize_matmul_node));
 }
 
+vkapi::ShaderInfo pick_matmul_naive_texture3d_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const bool is_transposed = graph->get_bool(resize_args.at(0));
+
+  std::string kernel_name =
+      is_transposed ? "matmul_transposed_naive" : "matmul_naive";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
 void add_matmul_naive_texture3d_node(
     ComputeGraph& graph,
     const ValueRef mat1,

@@ -122,19 +150,11 @@ void add_matmul_naive_texture3d_node(
      utils::kHeightPacked,
      /*passthrough = */ true);
 
-  std::string kernel_name = graph.get_bool(mat2_is_transposed)
-      ? "matmul_transposed_naive"
-      : "matmul_naive";
-  kernel_name.reserve(kShaderNameReserve);
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
-     VK_KERNEL_FROM_STR(kernel_name),
-     global_wg_size,
-     graph.create_local_wg_size(global_wg_size),
+     pick_matmul_naive_texture3d_shader,
+     default_pick_global_wg_size,
+     default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
      // Shader params buffers

@@ -156,6 +176,59 @@ void add_matmul_naive_texture3d_node(
      resize_matmul_node));
 }
 
+vkapi::ShaderInfo pick_matmul_optimized_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1_W_packed = resize_args.at(1);
+  const bool mat2_is_transposed_val = graph->get_bool(resize_args.at(0));
+
+  std::string kernel_name = mat2_is_transposed_val
+      ? "matmul_transposed_optimized"
+      : "matmul_optimized";
+
+  std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1_W_packed);
+  int mat1_dims = mat1_sizes.size();
+  if (mat1_dims == 3) {
+    kernel_name = "batch_" + kernel_name;
+  }
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    kernel_name += "_tile_row_2";
+  } else {
+    kernel_name += "_tile_row_4";
+  }
+
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
+utils::uvec3 matmul_optimized_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1_W_packed = resize_args.at(1);
+
+  const std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1_W_packed);
+  const int mat1_dims = mat1_sizes.size();
+
+  utils::uvec3 global_size = graph->logical_limits_of(out);
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    // Use `logical_extents` instead of `image_extents` because the workgroup
+    // axes need to correspond to tensor dimensions.
+    global_size = utils::divup_vec(global_size, {4, 2, 1});
+  } else {
+    global_size = utils::divup_vec(global_size, {4, 4, 1});
+  }
+
+  return global_size;
+}
+
 void add_matmul_optimized_node(
     ComputeGraph& graph,
     const ValueRef mat1,

@@ -192,45 +265,11 @@ void add_matmul_optimized_node(
     viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }
 
-  std::string kernel_name = mat2_is_transposed_val
-      ? "matmul_transposed_optimized"
-      : "matmul_optimized";
-
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
-  } else {
-    kernel_name += "_tile_row_4";
-  }
-
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the
-  // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is
-  // channels packed, C does not need to be divided by 4. The "identity" of each
-  // thread is the (x, y, z) coordinate of the output tile it is computing, and
-  // this identity can be used to compute the tensor index of the top left
-  // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0]
-  utils::uvec3 global_size = graph.logical_limits_of(out);
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    // Use `logical_extents` instead of `image_extents` because the workgroup
-    // axes need to correspond to tensor dimensions.
-    global_size = utils::divup_vec(global_size, {4, 2, 1});
-  } else {
-    global_size = utils::divup_vec(global_size, {4, 4, 1});
-  }
-
-  utils::uvec3 local_size = adaptive_work_group_size(global_size);
-
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
-     VK_KERNEL_FROM_STR(kernel_name),
-     global_size,
-     local_size,
+     pick_matmul_optimized_shader,
+     matmul_optimized_global_wg_size,
+     default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
      // Shader params buffers

@@ -246,7 +285,7 @@ void add_matmul_optimized_node(
      graph.hashed_layout_of(mat1_W_packed),
      graph.hashed_layout_of(mat2_packed)},
      // Resize Args
-     {mat2_is_transposed},
+     {mat2_is_transposed, mat1_W_packed},
      // Resizing Logic
      resize_matmul_node));
 }
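To make the new sizing concrete with made-up numbers: if the output's logical limits at encode time are {64, 96, 4} and mat1 has fewer than 8 rows, pick_matmul_optimized_shader selects a *_tile_row_2 shader variant and matmul_optimized_global_wg_size divides the grid by {4, 2, 1}, giving a global work group size of {16, 48, 4}. Because both callbacks run when the command buffer is encoded, these numbers track the current (virtually resized) sequence length rather than the maximum sequence length, which is where the speedup described in the commit message comes from.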
