@@ -8,6 +8,7 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 
@@ -37,12 +38,12 @@ void check_matmul_args(
 void resize_matmul_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
+    const std::vector<ValueRef>& resize_args) {
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
   vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
 
-  bool mat2_is_transposed = graph->get_bool(extra_args[0]);
+  bool mat2_is_transposed = graph->get_bool(resize_args[0]);
 
   const int out_cols = utils::val_at(-2, mat1->sizes());
   const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
@@ -56,6 +57,23 @@ void resize_matmul_node(
   out->virtual_resize(new_out_sizes);
 }
 
+/**
+ * Custom global workgroup size function for naive buffer matmul operations.
+ */
+utils::uvec3 matmul_naive_buffer_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
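+  // Dispatch one thread per output element: x spans the width dim, y the
+  // height dim, and z the flattened channel and batch dims.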
+  const ValueRef out = args.at(0).refs.at(0);
+  return {
+      graph->size_at<uint32_t>(-1, out),
+      graph->size_at<uint32_t>(-2, out),
+      graph->size_at<uint32_t>(-3, out) * graph->size_at<uint32_t>(-4, out)};
+}
+
 void add_matmul_naive_buffer_node(
     ComputeGraph& graph,
     const ValueRef mat1,
@@ -72,21 +90,16 @@ void add_matmul_naive_buffer_node(
   std::string kernel_name = "matmul_naive_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  utils::uvec3 global_size = {
-      graph.size_at<uint32_t>(-1, out),
-      graph.size_at<uint32_t>(-2, out),
-      graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
-
   int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
                                 graph.get_bool(mat2_is_transposed))
       ? 1
       : 0;
 
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      graph.create_local_wg_size(global_size),
+      matmul_naive_buffer_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
@@ -109,6 +122,22 @@ void add_matmul_naive_buffer_node(
       resize_matmul_node));
 }
 
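+/**
+ * Custom shader selection function for naive texture3d matmul; selects the
+ * transposed shader variant based on the resize args and appends storage type
+ * and dtype suffixes derived from the output tensor.
+ */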
+vkapi::ShaderInfo pick_matmul_naive_texture3d_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const bool is_transposed = graph->get_bool(resize_args.at(0));
+
+  std::string kernel_name =
+      is_transposed ? "matmul_transposed_naive" : "matmul_naive";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
 void add_matmul_naive_texture3d_node(
     ComputeGraph& graph,
     const ValueRef mat1,
@@ -122,19 +151,11 @@ void add_matmul_naive_texture3d_node(
       utils::kHeightPacked,
       /*passthrough = */ true);
 
-  std::string kernel_name = graph.get_bool(mat2_is_transposed)
-      ? "matmul_transposed_naive"
-      : "matmul_naive";
-  kernel_name.reserve(kShaderNameReserve);
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      VK_KERNEL_FROM_STR(kernel_name),
-      global_wg_size,
-      graph.create_local_wg_size(global_wg_size),
+      pick_matmul_naive_texture3d_shader,
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
@@ -156,6 +177,59 @@ void add_matmul_naive_texture3d_node(
       resize_matmul_node));
 }
 
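+/**
+ * Custom shader selection function for optimized matmul; selects between the
+ * transposed and non-transposed variants, adds a batch_ prefix for 3-dim
+ * inputs, and picks the tile row size based on the height of mat1.
+ */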
+vkapi::ShaderInfo pick_matmul_optimized_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1_W_packed = resize_args.at(1);
+  const bool mat2_is_transposed_val = graph->get_bool(resize_args.at(0));
+
+  std::string kernel_name = mat2_is_transposed_val
+      ? "matmul_transposed_optimized"
+      : "matmul_optimized";
+
+  std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1_W_packed);
+  size_t mat1_dims = mat1_sizes.size();
+  if (mat1_dims == 3) {
+    kernel_name = "batch_" + kernel_name;
+  }
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    kernel_name += "_tile_row_2";
+  } else {
+    kernel_name += "_tile_row_4";
+  }
+
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
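+/**
+ * Custom global workgroup size function for optimized matmul; divides the
+ * output's logical extents by the per-thread output tile size.
+ */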
+utils::uvec3 matmul_optimized_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1_W_packed = resize_args.at(1);
+
+  const std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1_W_packed);
+  const size_t mat1_dims = mat1_sizes.size();
+
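+  // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the
+  // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is
+  // channels packed, C does not need to be divided by 4. The "identity" of each
+  // thread is the (x, y, z) coordinate of the output tile it is computing, and
+  // this identity can be used to compute the tensor index of the top left
+  // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0]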
+  utils::uvec3 global_size = graph->logical_limits_of(out);
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    // Use `logical_extents` instead of `image_extents` because the workgroup
+    // axes need to correspond to tensor dimensions.
+    global_size = utils::divup_vec(global_size, {4, 2, 1});
+  } else {
+    global_size = utils::divup_vec(global_size, {4, 4, 1});
+  }
+
+  return global_size;
+}
+
 void add_matmul_optimized_node(
     ComputeGraph& graph,
     const ValueRef mat1,
@@ -192,45 +266,11 @@ void add_matmul_optimized_node(
     viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }
 
-  std::string kernel_name = mat2_is_transposed_val
-      ? "matmul_transposed_optimized"
-      : "matmul_optimized";
-
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
-  } else {
-    kernel_name += "_tile_row_4";
-  }
-
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the
-  // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is
-  // channels packed, C does not need to be divided by 4. The "identity" of each
-  // thread is the (x, y, z) coordinate of the output tile it is computing, and
-  // this identity can be used to compute the tensor index of the top left
-  // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0]
-  utils::uvec3 global_size = graph.logical_limits_of(out);
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    // Use `logical_extents` instead of `image_extents` because the workgroup
-    // axes need to correspond to tensor dimensions.
-    global_size = utils::divup_vec(global_size, {4, 2, 1});
-  } else {
-    global_size = utils::divup_vec(global_size, {4, 4, 1});
-  }
-
-  utils::uvec3 local_size = adaptive_work_group_size(global_size);
-
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      pick_matmul_optimized_shader,
+      matmul_optimized_global_wg_size,
+      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
      // Shader params buffers
@@ -246,7 +286,7 @@ void add_matmul_optimized_node(
       graph.hashed_layout_of(mat1_W_packed),
       graph.hashed_layout_of(mat2_packed)},
       // Resize Args
-      {mat2_is_transposed},
+      {mat2_is_transposed, mat1_W_packed},
       // Resizing Logic
       resize_matmul_node));
 }