Skip to content

Commit b6d7a76

Browse files
committed
Update on "[ET-VK] Adding batch processing in x axis to conv2d dw shader by caching input texel for reuse."
This diff adds batch processing in the x axis to the conv2d dw shader by reusing input texel overlapping between consecutive tiles. The changes include modifying the glsl code for the conv2d dw output tile, adding a new parameter to the yaml file, and modifying the Convolution.cpp file to use the new parameter. Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/) [ghstack-poisoned]
2 parents 9bfe284 + 519af1d commit b6d7a76

File tree

3 files changed

+34
-33
lines changed

3 files changed

+34
-33
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#define op(X, A, B) ${OPERATOR}
2222

23-
#include "indexing_utils_u16.h"
23+
#include "indexing_utils.h"
2424

2525
layout(std430) buffer;
2626

@@ -43,35 +43,34 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4343
* output at a single output location.
4444
*/
4545
void main() {
46-
// x divided up by batch size is used to determine 3d position
47-
// y divided up by batch size is used to determine 3d position
46+
// x and y are divided by batch size to determine 3d position
4847
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
49-
const ivec2 out_limits_xy_scaled = ivec2(out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y);
48+
const ivec2 out_limits_xy_scaled = (out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y);
5049

51-
u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled.x, out_limits_xy_scaled.y);
50+
ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled.x, out_limits_xy_scaled.y);
5251

5352
// scale pos.xy by batch sizes, because that's the top pixel to be processed
54-
pos.x *= uint16_t(BATCH_SIZE_X);
55-
pos.y *= uint16_t(BATCH_SIZE_Y);
53+
pos.x *= BATCH_SIZE_X;
54+
pos.y *= BATCH_SIZE_Y;
5655

5756
// do not process if top pixel does not fit within the output range
58-
if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) {
57+
if (any(greaterThanEqual(pos, out_limits))) {
5958
return;
6059
}
6160

6261
// Compute the index of the top-left element of the overlay region. Negative
6362
// indices indicate that the top-left element is in a region added by padding.
64-
const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding);
63+
const ivec2 ipos = pos.xy * stride - padding;
6564

6665
// Compute the start and end of the input indices to load. Padding is assumed
6766
// to be constant 0 padding, so any reads from the padding region is skipped.
68-
const u16vec2 start = ipos;
69-
const u16vec2 end = ipos + u16vec2(overlay_region.xy);
67+
const ivec2 start = ipos;
68+
const ivec2 end = ipos + overlay_region.xy;
7069

7170
// sum outputs
7271
VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X];
7372

74-
sum[0][0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
73+
sum[0][0] = texelFetch(t_bias, ivec2(pos.z, 0), 0);
7574
for (int y = 0; y < BATCH_SIZE_Y; y++) {
7675
for (int x = 0; x < BATCH_SIZE_X; x++) {
7776
sum[y][x] = sum[0][0];
@@ -84,39 +83,39 @@ void main() {
8483
// array to store kernel data of previous y
8584
VEC4_T prev_kernel_line[TILE_SIZE];
8685

87-
uint16_t kx = uint16_t(0);
88-
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) {
89-
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE + BATCH_SIZE_X - 1); x += uint16_t(dilation.x), j++) {
90-
in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
86+
int kx = 0;
87+
for (int y = start.y, i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; y += dilation.y, i++) {
88+
for (int x = start.x, j = 0; j < TILE_SIZE + BATCH_SIZE_X - 1; x += dilation.x, j++) {
89+
in_texels[j] = texelFetch(t_in, ivec3(x, y, pos.z), 0);
9190
}
9291

9392
// from 2nd iteration onwards accumulate dot product in 2nd sum
9493
// based on kernel line data fetched in previous iteration and input texel from this iteration
95-
if (i > uint16_t(0)) {
96-
for (uint16_t s = uint16_t(0); s < uint16_t(BATCH_SIZE_X); s++) {
97-
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) {
98-
sum[1][int(s)] = fma(in_texels[int(j+s)], prev_kernel_line[int(j)], sum[1][int(s)]);
94+
if (i > 0) {
95+
for (int j = 0; j < TILE_SIZE; j++) {
96+
for (int s = 0; s < BATCH_SIZE_X; s++) {
97+
sum[1][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[1][s]);
9998
}
10099
}
101100
}
102101

103102
// accumulate dot product in 1st sum only until tile size
104-
if (i < uint16_t(TILE_SIZE)) {
105-
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) {
106-
prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0);
107-
for (uint16_t s = uint16_t(0); s < uint16_t(BATCH_SIZE_X); s++) {
108-
sum[0][int(s)] = fma(in_texels[int(j+s)], prev_kernel_line[int(j)], sum[0][int(s)]);
103+
if (i < int(TILE_SIZE)) {
104+
for (int j = 0; j < TILE_SIZE; j++, kx++) {
105+
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
106+
for (int s = 0; s < BATCH_SIZE_X; s++) {
107+
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
109108
}
110109
}
111110
}
112111
}
113112

114-
for (int i = 0; i < BATCH_SIZE_Y; i++) {
115-
for (int j = 0; j < BATCH_SIZE_X; j++) {
116-
if (any(greaterThanEqual(u16vec3(pos.x + j, pos.y + i, pos.z), out_limits))) {
113+
for (int y = 0; y < BATCH_SIZE_Y; y++) {
114+
for (int x = 0; x < BATCH_SIZE_X; x++) {
115+
if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) {
117116
continue;
118117
}
119-
imageStore(t_out, u16vec3(pos.x + j, pos.y + i, pos.z), op(sum[i][j], out_min, out_max));
118+
imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
120119
}
121120
}
122121
}

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
223223
return pos;
224224
}
225225

226+
ivec3 idx_to_ipos_x_wise(uint idx, int size_x, int size_y) {
227+
const uint div_by_x = idx / size_x;
228+
return ivec3(idx % size_x, div_by_x % size_y, div_by_x / size_y);
229+
}
230+
226231
#ifdef USING_BUFFER
227232
#define load_texel(buf, idx) buf[idx]
228233
#elif defined(USING_TEXTURE2D)

backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313

1414
u16vec3 idx_to_u16pos_x_wise(uint idx, int size_x, int size_y) {
1515
const uint div_by_x = idx / size_x;
16-
return u16vec3(
17-
idx % size_x,
18-
div_by_x % size_y,
19-
div_by_x / size_y);
16+
return u16vec3(idx % size_x, div_by_x % size_y, div_by_x / size_y);
2017
}
2118

2219
#endif // INDEXING_UTILS_U16_H

0 commit comments

Comments
 (0)