Skip to content

Commit eeffa26

Browse files
authored
[ET-VK] Change all conv2d pw integer types from uint16 to int, since this slightly improves performance.
Differential Revision: D67906023
Pull Request resolved: #7545
1 parent 8b5adb8 commit eeffa26

File tree

3 files changed

+20
-41
lines changed

3 files changed

+20
-41
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#define op(X, A, B) ${OPERATOR}
1616

17-
#include "indexing_utils_u16.h"
17+
#include "indexing_utils.h"
1818

1919
layout(std430) buffer;
2020

@@ -35,7 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3535
* output at a single output location.
3636
*/
3737
void main() {
38-
const ivec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
38+
const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
3939

4040
if (any(greaterThanEqual(pos, out_limits))) {
4141
return;

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#define op(X, A, B) ${OPERATOR}
1818

19-
#include "indexing_utils_u16.h"
19+
#include "indexing_utils.h"
2020

2121
layout(std430) buffer;
2222

@@ -32,10 +32,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232

3333
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434

35-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36-
3735
// Shared memory to hold the calculated positions; this reduces register usage, thus improving performance.
38-
shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
36+
shared ivec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
3937

4038
/*
4139
* Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -46,18 +44,18 @@ void main() {
4644
const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
4745
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
4846

49-
const u16vec3 gpos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
47+
const ivec3 gpos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
5048

5149
// Output position for TILE_SIZE = 2
5250
// +--------+--------+
5351
// | pos[0] | pos[1] |
5452
// +--------+--------+
5553
// | pos[2] | pos[3] |
5654
// +--------+--------+
57-
u16vec2 pos[TILE_SIZE * TILE_SIZE];
55+
ivec2 pos[TILE_SIZE * TILE_SIZE];
5856
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
5957
for (int x = 0; x < TILE_SIZE; ++x) {
60-
pos[i] = u16vec2(
58+
pos[i] = ivec2(
6159
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
6260
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
6361
i++;
@@ -66,38 +64,38 @@ void main() {
6664

6765
// If the top left position is out of bounds, then this invocation will have
6866
// no work to do.
69-
if (any(greaterThanEqual(u16vec3(pos[0], gpos.z), out_limits))) {
67+
if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits))) {
7068
return;
7169
}
7270

7371
// Compute the index of the input texture that needs to be loaded for each
7472
// output position. Note that negative indices can be produced indicating that
7573
// the top-left element is in a region added by padding.
76-
u16vec2 ipos[TILE_SIZE * TILE_SIZE];
74+
ivec2 ipos[TILE_SIZE * TILE_SIZE];
7775
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
78-
ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
76+
ipos[i] = pos[i] * stride - padding;
7977
}
8078

8179
vec4 sum[TILE_SIZE * TILE_SIZE];
82-
sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0);
80+
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
8381
for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
8482
sum[i] = sum[0];
8583
}
8684

8785
int z4 = 0;
8886
// Since the kernel is 1x1, we only have to loop over the depth dimension.
89-
for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) {
87+
for (int z = 0; z < in_group_size; z += 4, ++z4) {
9088
// During prepacking, the weight tensor has been permuted so that the
9189
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
9290
// the z-axis.
93-
const vec4 ktex_0 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(0, 0));
94-
const vec4 ktex_1 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(1, 0));
95-
const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0));
96-
const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0));
91+
const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0));
92+
const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0));
93+
const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0));
94+
const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
9795

9896
#pragma unroll
9997
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
100-
const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
98+
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
10199
// For 2x2 tile size algorithm works as follows.
102100
// To explain the calculations below, the contents of one in_tex and the
103101
// group of 4 texels loaded from t_kernel are shown:
@@ -139,9 +137,9 @@ void main() {
139137
}
140138

141139
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
142-
const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
143-
if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) {
144-
imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
140+
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
141+
if (all(lessThan(ivec3(pos, gpos.z), out_limits))) {
142+
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
145143
}
146144
}
147145
}

backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)