Skip to content

Commit c805499

Browse files
committed
[ET-VK] Using shared variable to store calculated output pose to free up registers and improve performance.
Pull Request resolved: #7475 This diff introduces a shared variable to store calculated output pose in conv2d_pw op to free up registers and improve performance. The code changes include adding a shared variable to hold calculated positions and modifying the existing code to use the shared variable. ghstack-source-id: 260166242 Differential Revision: [D67742567](https://our.internmc.facebook.com/intern/diff/D67742567/)
1 parent c913e17 commit c805499

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434

3535
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
3636

37+
// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
38+
shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
39+
3740
/*
3841
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
3942
* output tile for pointwise convolution is more efficient because the kernel
4043
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
4144
*/
4245
void main() {
4346
const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1) / TILE_SIZE);
47+
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
4448

4549
const u16vec3 gpos = u16vec3(
4650
gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z),
@@ -58,6 +62,7 @@ void main() {
5862
for (int x = 0; x < TILE_SIZE; ++x) {
5963
pos[i] = u16vec2(
6064
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
65+
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
6166
i++;
6267
}
6368
}
@@ -73,7 +78,7 @@ void main() {
7378
// the top-left element is in a region added by padding.
7479
u16vec2 ipos[TILE_SIZE * TILE_SIZE];
7580
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
76-
ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding);
81+
ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
7782
}
7883

7984
vec4 sum[TILE_SIZE * TILE_SIZE];
@@ -138,8 +143,9 @@ void main() {
138143
}
139144

140145
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
141-
if (all(lessThan(u16vec3(pos[i], gpos.z), out_limits))) {
142-
imageStore(t_out, u16vec3(pos[i], gpos.z), op(sum[i], out_min, out_max));
146+
const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
147+
if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) {
148+
imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
143149
}
144150
}
145151
}

0 commit comments

Comments
 (0)