Skip to content

Commit 99284c7

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Migrate convolution workgroup API (#3996)
Summary: Pull Request resolved: #3996 The expectation these days is to always use the `ComputeGraph` API; we've just been lazy with migration. Doing convolution now since I'm playing with the workgroup size. ghstack-source-id: 230551344 bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: nathanaelsee Differential Revision: D58681683 fbshipit-source-id: e789018c59417c4e8c0649e89c87492a01be314a
1 parent ed81c5b commit 99284c7

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,11 @@ ValueRef prepack_biases(
9999

100100
api::ShaderInfo shader = get_nchw_to_tensor_shader(*t);
101101

102-
api::utils::uvec3 global_size = t->image_extents();
103-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
104-
105102
graph.prepack_nodes().emplace_back(new PrepackNode(
106103
graph,
107104
shader,
108-
global_size,
109-
local_size,
105+
graph.create_global_wg_size(v),
106+
graph.create_local_wg_size(v),
110107
vref,
111108
v,
112109
{t->sizes_ubo()},
@@ -203,17 +200,14 @@ ValueRef prepack_weights(
203200
final_sizes, graph.dtype_of(vref), api::kTexture2D, api::kChannelsPacked);
204201
vTensorPtr t = graph.get_tensor(v);
205202

206-
api::utils::uvec3 global_size = t->image_extents();
207-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
208-
209203
api::ShaderInfo shader =
210204
get_conv2d_shader(graph, *t, /*prepack_weights = */ true, method, vref);
211205

212206
graph.prepack_nodes().emplace_back(new PrepackNode(
213207
graph,
214208
shader,
215-
global_size,
216-
local_size,
209+
graph.create_global_wg_size(v),
210+
graph.create_local_wg_size(v),
217211
vref,
218212
v,
219213
{t->sizes_ubo(),
@@ -343,9 +337,6 @@ void add_conv2d_node(
343337
}
344338
check_conv_args(*t_in, *t_out);
345339

346-
api::utils::uvec3 global_size = t_out->image_extents();
347-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
348-
349340
Kernel2dParams kernel_params = create_kernel2d_params(
350341
graph,
351342
weight,
@@ -366,8 +357,8 @@ void add_conv2d_node(
366357
graph.execute_nodes().emplace_back(new ExecuteNode(
367358
graph,
368359
shader,
369-
global_size,
370-
local_size,
360+
graph.create_global_wg_size(out),
361+
graph.create_local_wg_size(out),
371362
// Inputs and Outputs
372363
{{out, api::MemoryAccessType::WRITE},
373364
{{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},

0 commit comments

Comments
 (0)