Skip to content

Commit 6cb9037

Browse files
[ET-VK] Adding a common utility function to calculate 3d output position based on unique index. (#7564)
Pull Request resolved: #7522 This diff adds an indexing utils header file used in Vulkan backend of Executorch. The header file includes functions for converting a global index to u16 indices based on input sizes. ghstack-source-id: 260707858 @exported-using-ghexport Differential Revision: [D67821941](https://our.internmc.facebook.com/intern/diff/D67821941/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent ca1f760 commit 6cb9037

File tree

4 files changed

+27
-18
lines changed

4 files changed

+27
-18
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#define op(X, A, B) ${OPERATOR}
1616

17-
#include "indexing_utils.h"
17+
#include "indexing_utils_u16.h"
1818

1919
layout(std430) buffer;
2020

@@ -35,10 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3535
* output at a single output location.
3636
*/
3737
void main() {
38-
const ivec3 pos = ivec3(
39-
gl_GlobalInvocationID.x % out_limits.x,
40-
(gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
41-
gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));
38+
const ivec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
4239

4340
if (any(greaterThanEqual(pos, out_limits))) {
4441
return;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
#define op(X, A, B) ${OPERATOR}
2020

21-
#include "indexing_utils.h"
21+
#include "indexing_utils_u16.h"
2222

2323
layout(std430) buffer;
2424

@@ -43,12 +43,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4343
void main() {
4444
// y divided up by batch size is used to determine 3d position
4545
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
46-
const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
46+
const int out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
4747

48-
u16vec3 pos = u16vec3(
49-
gl_GlobalInvocationID.x % out_limits.x,
50-
((gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled),
51-
gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled));
48+
u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits_y_scaled);
5249

5350
// scale pos.y by batch size, because that's the top pixel to be processed
5451
pos.y *= uint16_t(BATCH_SIZE_Y);

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#define op(X, A, B) ${OPERATOR}
1818

19-
#include "indexing_utils.h"
19+
#include "indexing_utils_u16.h"
2020

2121
layout(std430) buffer;
2222

@@ -43,13 +43,10 @@ shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroup
4343
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
4444
*/
4545
void main() {
46-
const uvec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
46+
const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
4747
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
4848

49-
const u16vec3 gpos = u16vec3(
50-
gl_GlobalInvocationID.x % out_limits_scaled.x,
51-
(gl_GlobalInvocationID.x / out_limits_scaled.x) % out_limits_scaled.y,
52-
gl_GlobalInvocationID.x / (out_limits_scaled.x * out_limits_scaled.y));
49+
const u16vec3 gpos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
5350

5451
// Output position for TILE_SIZE = 2
5552
// +--------+--------+
@@ -98,7 +95,6 @@ void main() {
9895
const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0));
9996
const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0));
10097

101-
10298
#pragma unroll
10399
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
104100
const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#ifndef INDEXING_UTILS_U16_H
10+
#define INDEXING_UTILS_U16_H
11+
12+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
13+
14+
// Converts a linear index into a 3D position, with x varying fastest:
//   x = idx % size_x
//   y = (idx / size_x) % size_y
//   z = (idx / size_x) / size_y
//
// idx:    linear index to decompose.
// size_x: extent used to wrap the x component.
// size_y: extent used to wrap the y component.
//
// Returns the position as a u16vec3. The z component is not wrapped here,
// so it grows with idx. Each component is narrowed to 16 bits by the
// u16vec3 constructor.
// NOTE(review): presumably callers guarantee every component fits in 16 bits
// and that size_x and size_y are positive - confirm against call sites.
u16vec3 idx_to_u16pos_x_wise(uint idx, int size_x, int size_y) {
15+
const uint div_by_x = idx / size_x;
16+
return u16vec3(idx % size_x, div_by_x % size_y, div_by_x / size_y);
17+
}
18+
19+
#endif // INDEXING_UTILS_U16_H

0 commit comments

Comments
 (0)