Skip to content

Commit 0bce3c6

Browse files
committed
[ET-VK] Adding batch processing in x axis to conv2d dw shader by caching input texel for reuse.
This diff adds batch processing in the x axis to the conv2d dw shader by reusing input texel overlapping between consecutive tiles. The changes include modifying the glsl code for the conv2d dw output tile, adding a new parameter to the yaml file, and modifying the Convolution.cpp file to use the new parameter. Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/) ghstack-source-id: 260383028 Pull Request resolved: #7526
1 parent 1cb5d97 commit 0bce3c6

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#define TILE_SIZE ${TILE_SIZE}
1616

17+
#define BATCH_SIZE_X ${BATCH_SIZE_X}
18+
1719
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
1820

1921
#define op(X, A, B) ${OPERATOR}
@@ -41,13 +43,15 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4143
* output at a single output location.
4244
*/
4345
void main() {
46+
// x divided up by batch size is used to determine 3d position
4447
// y divided up by batch size is used to determine 3d position
4548
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
46-
const int out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
49+
const ivec2 out_limits_xy_scaled = ivec2(out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y);
4750

48-
u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits_y_scaled);
51+
u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled.x, out_limits_xy_scaled.y);
4952

50-
// scale pos.y by batch size, because that's the top pixel to be processed
53+
// scale pos.xy by batch sizes, because that's the top pixel to be processed
54+
pos.x *= uint16_t(BATCH_SIZE_X);
5155
pos.y *= uint16_t(BATCH_SIZE_Y);
5256

5357
// do not process if top pixel does not fit within the output range
@@ -65,46 +69,54 @@ void main() {
6569
const u16vec2 end = ipos + u16vec2(overlay_region.xy);
6670

6771
// sum outputs
68-
VEC4_T sum[BATCH_SIZE_Y];
72+
VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X];
6973

70-
sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
71-
for (int i = 1; i < BATCH_SIZE_Y; i++) {
72-
sum[i] = sum[0];
74+
sum[0][0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
75+
for (int y = 0; y < BATCH_SIZE_Y; y++) {
76+
for (int x = 0; x < BATCH_SIZE_X; x++) {
77+
sum[y][x] = sum[0][0];
78+
}
7379
}
7480

7581
// array to store input texels
76-
VEC4_T in_texels[TILE_SIZE];
82+
VEC4_T in_texels[TILE_SIZE + BATCH_SIZE_X - 1];
7783

7884
// array to store kernel data of previous y
7985
VEC4_T prev_kernel_line[TILE_SIZE];
8086

8187
uint16_t kx = uint16_t(0);
8288
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) {
83-
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
89+
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE + BATCH_SIZE_X - 1); x += uint16_t(dilation.x), j++) {
8490
in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
8591
}
8692

8793
// from 2nd iteration onwards accumulate dot product in 2nd sum
8894
// based on kernel line data fetched in previous iteration and input texel from this iteration
8995
if (i > uint16_t(0)) {
90-
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) {
91-
sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]);
96+
for (uint16_t s = uint16_t(0); s < uint16_t(BATCH_SIZE_X); s++) {
97+
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) {
98+
sum[1][int(s)] = fma(in_texels[int(j+s)], prev_kernel_line[int(j)], sum[1][int(s)]);
99+
}
92100
}
93101
}
94102

95103
// accumulate dot product in 1st sum only until tile size
96104
if (i < uint16_t(TILE_SIZE)) {
97105
for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) {
98106
prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0);
99-
sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]);
107+
for (uint16_t s = uint16_t(0); s < uint16_t(BATCH_SIZE_X); s++) {
108+
sum[0][int(s)] = fma(in_texels[int(j+s)], prev_kernel_line[int(j)], sum[0][int(s)]);
109+
}
100110
}
101111
}
102112
}
103113

104114
for (int i = 0; i < BATCH_SIZE_Y; i++) {
105-
if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
106-
continue;
115+
for (int j = 0; j < BATCH_SIZE_X; j++) {
116+
if (any(greaterThanEqual(u16vec3(pos.x + j, pos.y + i, pos.z), out_limits))) {
117+
continue;
118+
}
119+
imageStore(t_out, u16vec3(pos.x + j, pos.y + i, pos.z), op(sum[i][j], out_min, out_max));
107120
}
108-
imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
109121
}
110122
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ conv2d_dw_output_tile:
1010
NDIM: 3
1111
DTYPE: float
1212
TILE_SIZE: 3
13+
BATCH_SIZE_X: 4
1314
BATCH_SIZE_Y: 2
1415
generate_variant_forall:
1516
DTYPE:

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ utils::uvec3 create_conv2d_global_wg_size(
299299
} else if (method == Conv2dMethod::Depthwise) {
300300
const utils::uvec3 image_extents = graph.logical_limits_of(out);
301301
return {
302-
utils::div_up(image_extents[0u], 1u),
302+
utils::div_up(image_extents[0u], 4u),
303303
utils::div_up(image_extents[1u], 2u),
304304
image_extents[2u]};
305305
} else {

0 commit comments

Comments
 (0)