```diff
@@ -8,6 +8,7 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 
@@ -37,12 +38,12 @@ void check_matmul_args(
 void resize_matmul_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
+    const std::vector<ValueRef>& resize_args) {
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
   vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);
 
-  bool mat2_is_transposed = graph->get_bool(extra_args[0]);
+  bool mat2_is_transposed = graph->get_bool(resize_args[0]);
 
   const int out_cols = utils::val_at(-2, mat1->sizes());
   const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
@@ -56,6 +57,22 @@ void resize_matmul_node(
   out->virtual_resize(new_out_sizes);
 }
 
+/**
+ * Custom global workgroup size function for naive buffer matmul operations.
+ */
+utils::uvec3 matmul_naive_buffer_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  const ValueRef out = args.at(0).refs.at(0);
+  return {
+      graph->size_at<uint32_t>(-1, out),
+      graph->size_at<uint32_t>(-2, out),
+      graph->size_at<uint32_t>(-3, out) * graph->size_at<uint32_t>(-4, out)};
+}
+
 void add_matmul_naive_buffer_node(
     ComputeGraph& graph,
     const ValueRef mat1,
```
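For intuition, this picker launches one invocation per output element: width on X, height on Y, and channels times batch folded onto Z. A standalone sketch of the same arithmetic, where `size_at` is a hypothetical stand-in for `ComputeGraph::size_at<uint32_t>` (assumed here to treat missing leading dims as 1):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for ComputeGraph::size_at<uint32_t>: negative
// indices count back from the innermost (width) dim; dims that don't
// exist are treated as 1 (an assumption about the real helper).
uint32_t size_at(const std::vector<int64_t>& sizes, int dim) {
  const int i = static_cast<int>(sizes.size()) + dim;
  return i >= 0 ? static_cast<uint32_t>(sizes[i]) : 1u;
}

int main() {
  // Made-up output tensor with sizes [N=2, C=3, H=4, W=5].
  const std::vector<int64_t> out = {2, 3, 4, 5};
  const uint32_t x = size_at(out, -1);                    // W = 5
  const uint32_t y = size_at(out, -2);                    // H = 4
  const uint32_t z = size_at(out, -3) * size_at(out, -4); // C * N = 6
  std::printf("global wg = {%u, %u, %u}\n", x, y, z);     // {5, 4, 6}
}
```

Because this now runs inside a picker rather than once at build time, the grid follows the output whenever `resize_matmul_node` virtually resizes it.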
```diff
@@ -72,21 +89,16 @@ void add_matmul_naive_buffer_node(
   std::string kernel_name = "matmul_naive_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  utils::uvec3 global_size = {
-      graph.size_at<uint32_t>(-1, out),
-      graph.size_at<uint32_t>(-2, out),
-      graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
-
   int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
                                 graph.get_bool(mat2_is_transposed))
       ? 1
       : 0;
 
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      graph.create_local_wg_size(global_size),
+      matmul_naive_buffer_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
```
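This hunk shows the crux of the whole diff: a `DispatchNode` receives concrete global/local sizes once, at graph build time, so shapes that change later leave the dispatch stale, while a `DynamicDispatchNode` stores picker callbacks and re-runs them on re-encode. A minimal sketch of that pattern with simplified stand-in types (not the real ExecuTorch classes):

```cpp
#include <cstdint>
#include <cstdio>

// Simplified stand-ins for illustration; not the real ExecuTorch types.
struct uvec3 { uint32_t x, y, z; };
struct Graph { uvec3 out_extents; };

using PickGlobalFn = uvec3 (*)(const Graph&);
using PickLocalFn = uvec3 (*)(const Graph&, uvec3 global);

// Instead of baked-in sizes, the node holds the picker functions and
// re-evaluates them each time the command buffer is re-encoded, so the
// dispatch tracks the current tensor shapes.
struct DynamicDispatch {
  PickGlobalFn pick_global;
  PickLocalFn pick_local;
  void encode(const Graph& g) const {
    const uvec3 global = pick_global(g);        // recomputed per encode
    const uvec3 local = pick_local(g, global);  // recomputed per encode
    std::printf("dispatch global={%u,%u,%u} local={%u,%u,%u}\n",
                global.x, global.y, global.z, local.x, local.y, local.z);
  }
};

int main() {
  DynamicDispatch node{
      [](const Graph& g) { return g.out_extents; },
      [](const Graph&, uvec3) { return uvec3{8, 8, 1}; }};
  Graph g{{16, 16, 4}};
  node.encode(g);
  g.out_extents = {32, 8, 2}; // simulate a resize; encoding again follows it
  node.encode(g);
}
```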
```diff
@@ -109,6 +121,22 @@ void add_matmul_naive_buffer_node(
       resize_matmul_node));
 }
 
+vkapi::ShaderInfo pick_matmul_naive_texture3d_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const bool is_transposed = graph->get_bool(resize_args.at(0));
+
+  std::string kernel_name =
+      is_transposed ? "matmul_transposed_naive" : "matmul_naive";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
 void add_matmul_naive_texture3d_node(
     ComputeGraph& graph,
     const ValueRef mat1,
```
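With the name composition inside a picker, the transposed/non-transposed variant is likewise chosen per encode from the current value of `resize_args.at(0)`. To make the suffix mechanics concrete, here is a standalone mock-up; the exact suffix spellings appended by `add_storage_type_suffix` and `add_dtype_suffix` are assumptions here:

```cpp
#include <iostream>
#include <string>

int main() {
  const bool is_transposed = false;
  std::string kernel_name =
      is_transposed ? "matmul_transposed_naive" : "matmul_naive";
  // Assumed suffix spellings for a texture-backed fp32 output; the real
  // strings are appended by add_storage_type_suffix / add_dtype_suffix.
  kernel_name += "_texture3d";
  kernel_name += "_float";
  std::cout << kernel_name << '\n'; // matmul_naive_texture3d_float
}
```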
```diff
@@ -122,19 +150,11 @@ void add_matmul_naive_texture3d_node(
       utils::kHeightPacked,
       /*passthrough = */ true);
 
-  std::string kernel_name = graph.get_bool(mat2_is_transposed)
-      ? "matmul_transposed_naive"
-      : "matmul_naive";
-  kernel_name.reserve(kShaderNameReserve);
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      VK_KERNEL_FROM_STR(kernel_name),
-      global_wg_size,
-      graph.create_local_wg_size(global_wg_size),
+      pick_matmul_naive_texture3d_shader,
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
```
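Note the defaults this hunk leans on: per the deleted lines, an equivalent dispatch is the output's `logical_limits_of` extents with a derived local size, and `default_pick_global_wg_size` / `default_pick_local_wg_size` presumably reproduce that behavior; they appear to come with the `Common.h` include added at the top of this diff.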
```diff
@@ -156,6 +176,59 @@ void add_matmul_naive_texture3d_node(
       resize_matmul_node));
 }
 
+vkapi::ShaderInfo pick_matmul_optimized_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1_W_packed = resize_args.at(1);
+  const bool mat2_is_transposed_val = graph->get_bool(resize_args.at(0));
+
+  std::string kernel_name = mat2_is_transposed_val
+      ? "matmul_transposed_optimized"
+      : "matmul_optimized";
+
+  std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1_W_packed);
+  int mat1_dims = mat1_sizes.size();
+  if (mat1_dims == 3) {
+    kernel_name = "batch_" + kernel_name;
+  }
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    kernel_name += "_tile_row_2";
+  } else {
+    kernel_name += "_tile_row_4";
+  }
+
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
+utils::uvec3 matmul_optimized_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1_W_packed = resize_args.at(1);
+
+  const std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1_W_packed);
+  const int mat1_dims = mat1_sizes.size();
+
+  utils::uvec3 global_size = graph->logical_limits_of(out);
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    // Use `logical_extents` instead of `image_extents` because the workgroup
+    // axes need to correspond to tensor dimensions.
+    global_size = utils::divup_vec(global_size, {4, 2, 1});
+  } else {
+    global_size = utils::divup_vec(global_size, {4, 4, 1});
+  }
+
+  return global_size;
+}
+
 void add_matmul_optimized_node(
     ComputeGraph& graph,
     const ValueRef mat1,
```
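Each invocation of the optimized shader computes a multi-element output tile, so the global grid is the output's logical extents divided, rounding up, by the tile shape: `{4, 2, 1}` when `mat1` has fewer than 8 rows, `{4, 4, 1}` otherwise. A standalone sketch of that arithmetic (`divup` mimics `utils::divup_vec` component-wise; the extents are made up):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Component-wise ceiling division, mimicking utils::divup_vec.
uint32_t divup(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // Made-up logical extents of the output image: {X=10, Y=6, Z=3}.
  std::array<uint32_t, 3> global = {10, 6, 3};
  const bool few_rows = true; // mat1 rows < 8 selects the 2-row tile
  const std::array<uint32_t, 3> tile =
      few_rows ? std::array<uint32_t, 3>{4, 2, 1}
               : std::array<uint32_t, 3>{4, 4, 1};
  for (int i = 0; i < 3; ++i) {
    global[i] = divup(global[i], tile[i]);
  }
  // 2-row tiles: {ceil(10/4), ceil(6/2), ceil(3/1)} = {3, 3, 3}.
  std::printf("{%u, %u, %u}\n", global[0], global[1], global[2]);
}
```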
```diff
@@ -192,45 +265,11 @@ void add_matmul_optimized_node(
     viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }
 
-  std::string kernel_name = mat2_is_transposed_val
-      ? "matmul_transposed_optimized"
-      : "matmul_optimized";
-
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
-  } else {
-    kernel_name += "_tile_row_4";
-  }
-
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the
-  // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is
-  // channels packed, C does not need to be divided by 4. The "identity" of each
-  // thread is the (x, y, z) coordinate of the output tile it is computing, and
-  // this identity can be used to compute the tensor index of the top left
-  // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0]
-  utils::uvec3 global_size = graph.logical_limits_of(out);
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    // Use `logical_extents` instead of `image_extents` because the workgroup
-    // axes need to correspond to tensor dimensions.
-    global_size = utils::divup_vec(global_size, {4, 2, 1});
-  } else {
-    global_size = utils::divup_vec(global_size, {4, 4, 1});
-  }
-
-  utils::uvec3 local_size = adaptive_work_group_size(global_size);
-
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      pick_matmul_optimized_shader,
+      matmul_optimized_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
       // Shader params buffers
@@ -246,7 +285,7 @@ void add_matmul_optimized_node(
       graph.hashed_layout_of(mat1_W_packed),
       graph.hashed_layout_of(mat2_packed)},
       // Resize Args
-      {mat2_is_transposed},
+      {mat2_is_transposed, mat1_W_packed},
       // Resizing Logic
       resize_matmul_node));
 }
```
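Passing `mat1_W_packed` through the resize args is what lets the two new pickers run without capturing state: by this diff's convention, `resize_args.at(0)` carries the `mat2_is_transposed` flag and `resize_args.at(1)` the W-packed `mat1` tensor, so both the shader variant and the tile-based grid can be recomputed from live tensor sizes on every re-encode.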