Skip to content

Commit 6bac017

Browse files
committed
Update on "[ET-VK] Modify quantized linear tiling shader to linearly dispatch work to improve thread occupancy and performance."
This diff changes the tiled 8-bit quantized linear matmul op to dispatch its work linearly, which increases thread occupancy and improves performance. Differential Revision: [D73751979](https://our.internmc.facebook.com/intern/diff/D73751979/) [ghstack-poisoned]
1 parent 16732cc commit 6bac017

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,21 @@ layout(push_constant) uniform restrict Block {
3838
ivec4 weight_sizes;
3939
};
4040

41+
#include "indexing_utils.h"
42+
4143
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4244

4345
shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
4446

4547
void main() {
46-
const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
47-
const uint out_col = gl_GlobalInvocationID.x << 2;
48+
const uint out_width_ntexels = divup4(out_sizes.x);
49+
const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2;
50+
const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
4851

4952
const int gid = int(gl_LocalInvocationID.x); // group id
5053
const int wid = int(gl_LocalInvocationID.z); // worker id
5154

52-
if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
55+
if (out_row >= out_sizes.y) {
5356
return;
5457
}
5558

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ layout(push_constant) uniform restrict Block {
4141
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4242

4343
void main() {
44-
const uint out_size_x_div_4 = divup4(out_sizes.x);
45-
const uint out_col = (gl_GlobalInvocationID.x % out_size_x_div_4) << 2;
46-
const uint out_row = (gl_GlobalInvocationID.x / out_size_x_div_4) * TILE_ROWS;
44+
const uint out_width_ntexels = divup4(out_sizes.x);
45+
const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2;
46+
const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
4747

4848
if (out_row >= out_sizes.y) {
4949
return;

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,11 @@ void add_q_8w_linear_tiled_node(
195195
out_tile_nrows = 4;
196196
}
197197

198-
utils::uvec3 global_wg_size = graph.logical_limits_of(out);
199-
global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
200-
if (!use_coop_algorithm) {
201-
global_wg_size[0] *= global_wg_size[1];
202-
global_wg_size[1] = 1;
203-
}
198+
utils::uvec3 out_limits = graph.logical_limits_of(out);
199+
utils::uvec3 global_wg_size = {
200+
out_limits[0] * (utils::div_up(out_limits, out_tile_nrows)),
201+
1,
202+
out_limit[2]};
204203

205204
utils::uvec3 local_wg_size{64, 1, 1};
206205
if (use_coop_algorithm) {

0 commit comments

Comments
 (0)