Skip to content

Commit ce48e0d

Browse files
[ET-VK] Minor dispatch improvement to conv2d dw op to improve performance. (#11495)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11477 by @trivedivivek (please use that PR as the source of truth for the PR details, comments, and reviews).
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/113/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/113/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/112/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/113/orig
@diff-train-skip-merge
---------
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 93ea14f commit ce48e0d

File tree

3 files changed

+15
-16
lines changed

3 files changed

+15
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,25 @@ void main() {
6060
const uint div_by_x = gl_GlobalInvocationID.x / out_limits_xy_scaled.x;
6161
ivec3 pos = ivec3(
6262
gl_GlobalInvocationID.x % out_limits_xy_scaled.x,
63-
div_by_x % out_limits_xy_scaled.y,
64-
div_by_x / out_limits_xy_scaled.y);
65-
66-
// scale pos.xy by batch sizes, because that's the top pixel to be processed
67-
pos.x *= BATCH_SIZE_X;
68-
pos.y *= BATCH_SIZE_Y;
63+
div_by_x,
64+
gl_GlobalInvocationID.y);
6965

7066
// do not process if top pixel does not fit within the output range
71-
if (pos.z >= out_limits.z) {
67+
if (pos.y >= out_limits_xy_scaled.y || pos.z >= out_limits.z) {
7268
return;
7369
}
7470

71+
// scale pos.xy by batch sizes, because that's the top pixel to be processed
72+
pos.x *= BATCH_SIZE_X;
73+
pos.y *= BATCH_SIZE_Y;
74+
7575
// Compute the index of the top-left element of the overlay region. Negative
7676
// indices indicate that the top-left element is in a region added by padding.
7777
const ivec2 ipos = pos.xy * stride - padding;
7878

7979
// Compute the start and end of the input indices to load. Padding is assumed
8080
// to be constant 0 padding, so any reads from the padding region are skipped.
8181
const ivec2 start = ipos;
82-
const ivec2 end = ipos + overlay_region.xy;
8382

8483
// sum outputs
8584
VEC4_T sum[BATCH_SIZE_Y * BATCH_SIZE_X];

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,11 @@ void main() {
5050
const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
5151
const ivec3 pos = ivec3(
5252
gl_GlobalInvocationID.x % out_limits.x,
53-
div_by_x % out_limits.y,
54-
div_by_x / out_limits.y);
53+
div_by_x,
54+
gl_GlobalInvocationID.y);
5555

56-
if (pos.z >= out_limits.z) {
56+
// do not process if top pixel does not fit within the output range
57+
if (pos.y >= out_limits.y || pos.z >= out_limits.z) {
5758
return;
5859
}
5960

@@ -64,7 +65,6 @@ void main() {
6465
// Compute the start and end of the input indices to load. Padding is assumed
6566
// to be constant 0 padding, so any reads from the padding region are skipped.
6667
const ivec2 start = ipos;
67-
const ivec2 end = ipos + overlay_region.xy;
6868

6969
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
7070
int kx = 0;

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,11 @@ void add_conv2d_node(
407407
utils::uvec3 wg_size = create_conv2d_global_wg_size(
408408
graph, method, out, weight_data, stride_equals_dilation);
409409

410-
if (method == Conv2dMethod::Depthwise) {
411-
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
412-
} else if (method == Conv2dMethod::Pointwise) {
410+
utils::uvec3 local_wg_size;
411+
if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
413412
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
414413
}
415414

416-
utils::uvec3 local_wg_size;
417415
if (method == Conv2dMethod::Pointwise) {
418416
uint32_t local_wg_size_y = 1;
419417
if (wg_size[1] % 8 == 0) {
@@ -424,6 +422,8 @@ void add_conv2d_node(
424422
local_wg_size_y = 2;
425423
}
426424
local_wg_size = {64 / local_wg_size_y, local_wg_size_y, 1};
425+
} else if (method == Conv2dMethod::Depthwise) {
426+
local_wg_size = {64, 1, 1};
427427
} else {
428428
local_wg_size = graph.create_local_wg_size(wg_size);
429429
}

0 commit comments

Comments
 (0)