Skip to content

Commit b9a1762

Browse files
[ET-VK] Rearranging code in permute op shader to reduce heavy math ops and improve performance. (#7095)
Pull Request resolved: #7014 The diff rearranges Permute op shader code in executorch vulkan backend to reduce heavy math operations and improve performance. The change also include adding an extension to support explicit arithmetic types and changing the data type of the position variable to u16vec3. ghstack-source-id: 255546339 @exported-using-ghexport Differential Revision: [D66174765](https://our.internmc.facebook.com/intern/diff/D66174765/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent b8fbc48 commit b9a1762

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

backends/vulkan/runtime/graph/ops/glsl/permute.glsl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block {
3636

3737
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3838

39+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
40+
3941
void main() {
40-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
42+
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
4143

4244
if (any(greaterThanEqual(pos, out_limits))) {
4345
return;
@@ -46,28 +48,34 @@ void main() {
4648
const int out_channel_4up = int(ch_info.x);
4749
const int in_channel_4up = int(ch_info.y);
4850
const int out_batch = int(sizes[3]);
49-
const int max_dst_index = out_batch * out_channel_4up;
5051
VEC4_T outval = VEC4_T(0.0);
52+
ivec4 v = ivec4(0); // holds b,c,h,w
53+
54+
v[out_ndims[2]] = pos.y;
55+
v[out_ndims[3]] = pos.x;
56+
57+
const int dst_index = pos.z << 2;
58+
int dst_out_index = dst_index / out_channel_4up;
59+
int dst_out_lane = dst_index % out_channel_4up;
5160

52-
for (int j = 0; j < 4; ++j) {
53-
int dst_index = pos.z * 4 + j;
54-
if (dst_index >= max_dst_index) {
61+
for (int j = 0; j < 4; ++j, ++dst_out_lane) {
62+
if (dst_out_index >= out_batch) {
5563
// out of range
5664
break;
5765
}
5866

59-
ivec4 v = ivec4(0); // holds b,c,h,w
60-
v[out_ndims[0]] = dst_index / out_channel_4up;
61-
v[out_ndims[1]] = dst_index % out_channel_4up;
62-
v[out_ndims[2]] = pos.y;
63-
v[out_ndims[3]] = pos.x;
67+
if (dst_out_lane == out_channel_4up) {
68+
dst_out_lane = 0;
69+
dst_out_index++;
70+
}
71+
72+
v[out_ndims[0]] = dst_out_index;
73+
v[out_ndims[1]] = dst_out_lane;
6474

6575
int src_index = v[0] * in_channel_4up + v[1];
66-
int w = v[3];
67-
int h = v[2];
6876

69-
VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0));
70-
outval[j] = inval[src_index % 4];
77+
VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(v[3], v[2], src_index >> 2), 0));
78+
outval[j] = inval[src_index & 0x3];
7179
}
7280

7381
imageStore(image_out, pos, outval);

0 commit comments

Comments
 (0)