Skip to content

Commit 56cacd2

Browse files
committed
[ET-VK] Reduce conv2d_pw global wg size
Since each invocation of our conv2d_pw shader computes a 2x2 tile of the output, we can reduce the number of shader invocations, i.e., reduce the global workgroup size, by a factor of 2 along both the x and y axes.
Differential Revision: [D58956301](https://our.internmc.facebook.com/intern/diff/D58956301/)
ghstack-source-id: 231375285
Pull Request resolved: #4045
1 parent 398ce66 commit 56cacd2

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ void main() {
6666
// | pos[2] | pos[3] |
6767
// +--------+--------+
6868
ivec3 pos[${TILE_SIZE * TILE_SIZE}];
69-
for (int y = 0, i = 0; y < 2; ++y) {
70-
for (int x = 0; x < 2; ++x) {
69+
for (int y = 0, i = 0; y < ${TILE_SIZE}; ++y) {
70+
for (int x = 0; x < ${TILE_SIZE}; ++x) {
7171
pos[i] = ivec3(
72-
gpos.x * 2 + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
72+
gpos.x * ${TILE_SIZE} + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
7373
i++;
7474
}
7575
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,21 @@ Conv2dMethod get_conv2d_method(
289289
return Conv2dMethod::SlidingWindow;
290290
}
291291

292+
// Picks the global workgroup size used to dispatch a conv2d shader.
//
// The pointwise (conv2d_pw) shader has each invocation cover a 2x2 tile of the
// output, so for that method the x and y image extents of `out` are halved,
// rounding up so that odd extents still get a final, partially filled tile.
// The z extent is passed through unchanged. Every other method falls back to
// the graph's default global workgroup size for `out`.
api::utils::uvec3 create_conv2d_global_wg_size(
    ComputeGraph& graph,
    const Conv2dMethod method,
    const ValueRef out) {
  if (method != Conv2dMethod::Pointwise) {
    return graph.create_global_wg_size(out);
  }

  const api::utils::uvec3 extents = graph.image_extents_of(out);
  // The divisor 2u must stay in sync with the tile size used by conv2d_pw.glsl.
  const auto tiles_x = api::utils::div_up(extents.data[0u], 2u);
  const auto tiles_y = api::utils::div_up(extents.data[1u], 2u);
  return {tiles_x, tiles_y, extents.data[2u]};
}
306+
292307
void add_conv2d_node(
293308
ComputeGraph& graph,
294309
const ValueRef in,
@@ -357,7 +372,7 @@ void add_conv2d_node(
357372
graph.execute_nodes().emplace_back(new ExecuteNode(
358373
graph,
359374
shader,
360-
graph.create_global_wg_size(out),
375+
create_conv2d_global_wg_size(graph, method, out),
361376
graph.create_local_wg_size(out),
362377
// Inputs and Outputs
363378
{{out, api::MemoryAccessType::WRITE},

backends/vulkan/test/op_tests/cases.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,17 @@ def get_conv_inputs():
308308
[0],
309309
5,
310310
),
311+
(
312+
(1, 16, 672, 512),
313+
(64, 16, 1, 1),
314+
(64,),
315+
[1, 1],
316+
[0, 0],
317+
[1, 1],
318+
False,
319+
[0, 0],
320+
1,
321+
),
311322
]
312323
)
313324
return test_suite

0 commit comments

Comments
 (0)