Skip to content

Commit b1f74b6

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Reduce conv2d_pw global wg size (#4045)
Summary: Pull Request resolved: #4045 Since our conv2d_pw shader uses 2x2 tiles over the output, we can reduce the number of shader invocations, i.e., reduce the global workgroup size. ghstack-source-id: 231375285 Reviewed By: SS-JIA Differential Revision: D58956301 fbshipit-source-id: 34510c4584b5cf343d6a0c62d3fe42d27ec934e6
1 parent 16941c9 commit b1f74b6

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ void main() {
6666
// | pos[2] | pos[3] |
6767
// +--------+--------+
6868
ivec3 pos[${TILE_SIZE * TILE_SIZE}];
69-
for (int y = 0, i = 0; y < 2; ++y) {
70-
for (int x = 0; x < 2; ++x) {
69+
for (int y = 0, i = 0; y < ${TILE_SIZE}; ++y) {
70+
for (int x = 0; x < ${TILE_SIZE}; ++x) {
7171
pos[i] = ivec3(
72-
gpos.x * 2 + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
72+
gpos.x * ${TILE_SIZE} + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
7373
i++;
7474
}
7575
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,21 @@ Conv2dMethod get_conv2d_method(
289289
return Conv2dMethod::SlidingWindow;
290290
}
291291

292+
api::utils::uvec3 create_conv2d_global_wg_size(
293+
ComputeGraph& graph,
294+
const Conv2dMethod method,
295+
const ValueRef out) {
296+
if (method == Conv2dMethod::Pointwise) {
297+
const api::utils::uvec3 image_extents = graph.image_extents_of(out);
298+
return {
299+
api::utils::div_up(image_extents.data[0u], 2u),
300+
api::utils::div_up(image_extents.data[1u], 2u),
301+
image_extents.data[2u]};
302+
} else {
303+
return graph.create_global_wg_size(out);
304+
}
305+
}
306+
292307
void add_conv2d_node(
293308
ComputeGraph& graph,
294309
const ValueRef in,
@@ -357,7 +372,7 @@ void add_conv2d_node(
357372
graph.execute_nodes().emplace_back(new ExecuteNode(
358373
graph,
359374
shader,
360-
graph.create_global_wg_size(out),
375+
create_conv2d_global_wg_size(graph, method, out),
361376
graph.create_local_wg_size(out),
362377
// Inputs and Outputs
363378
{{out, api::MemoryAccessType::WRITE},

backends/vulkan/test/op_tests/cases.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,17 @@ def get_conv_inputs():
312312
[0],
313313
5,
314314
),
315+
(
316+
(1, 16, 672, 512),
317+
(64, 16, 1, 1),
318+
(64,),
319+
[1, 1],
320+
[0, 0],
321+
[1, 1],
322+
False,
323+
[0, 0],
324+
1,
325+
),
315326
]
316327
)
317328
return test_suite

0 commit comments

Comments
 (0)