Commit 3ead5c9

[ET-VK] Adding batch processing to conv2d dw shader by caching input texel and kernel values for reuse.
Differential Revision: D67774359
Pull Request resolved: #7485
1 parent cfeba33 commit 3ead5c9
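
The idea of the change: instead of one invocation per output texel, each invocation of the depthwise conv2d shader now produces BATCH_SIZE_Y vertically adjacent output rows, reusing the input texels and kernel rows it has already fetched. Below is a small standalone C++ sketch of the batched index math the shader uses to map a flattened invocation id to its top output position; the names (out_limits, BATCH_SIZE_Y) mirror the shader, and the example values are hypothetical.

#include <array>
#include <cstdint>
#include <cstdio>

// Sketch of the batched index math from conv2d_dw_output_tile.glsl: the y
// extent of the dispatch is divided by BATCH_SIZE_Y (rounded up), and the
// recovered y coordinate is scaled back up to the top row of the batch.
// This is an illustration, not the runtime's actual dispatch code.
constexpr uint32_t BATCH_SIZE_Y = 2;

std::array<uint32_t, 3> invocation_to_pos(
    uint32_t global_id_x,
    const std::array<uint32_t, 3>& out_limits) {
  const uint32_t out_limits_y_scaled =
      (out_limits[1] + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
  uint32_t x = global_id_x % out_limits[0];
  uint32_t y = (global_id_x / out_limits[0]) % out_limits_y_scaled;
  uint32_t z = global_id_x / (out_limits[0] * out_limits_y_scaled);
  y *= BATCH_SIZE_Y; // top row of the batch handled by this invocation
  return {x, y, z};
}

int main() {
  const std::array<uint32_t, 3> out_limits = {8, 7, 4}; // hypothetical W, H, C texels
  const auto pos = invocation_to_pos(13, out_limits);
  std::printf("pos = (%u, %u, %u)\n", pos[0], pos[1], pos[2]); // (5, 2, 0)
  return 0;
}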

3 files changed: +59 -13 lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 52 additions & 13 deletions
@@ -14,6 +14,8 @@
 
 #define TILE_SIZE ${TILE_SIZE}
 
+#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
+
 #define op(X, A, B) ${OPERATOR}
 
 #include "indexing_utils.h"
@@ -39,12 +41,20 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
  * output at a single output location.
  */
 void main() {
-  const u16vec3 pos = u16vec3(
+  // y divided up by batch size is used to determine 3d position
+  // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
+  const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
+
+  u16vec3 pos = u16vec3(
       gl_GlobalInvocationID.x % out_limits.x,
-      (gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
-      gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));
+      ((gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled),
+      gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled));
 
-  if (any(greaterThanEqual(pos, out_limits))) {
+  // scale pos.y by batch size, because that's the top pixel to be processed
+  pos.y *= uint16_t(BATCH_SIZE_Y);
+
+  // do not process if top pixel does not fit within the output range
+  if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) {
     return;
   }
 
@@ -57,18 +67,47 @@ void main() {
   const u16vec2 start = ipos;
   const u16vec2 end = ipos + u16vec2(overlay_region.xy);
 
-  VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
+  // sum outputs
+  VEC4_T sum[BATCH_SIZE_Y];
+
+  sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
+  for (int i = 1; i < BATCH_SIZE_Y; i++) {
+    sum[i] = sum[0];
+  }
+
+  // array to store input texels
+  VEC4_T in_texels[TILE_SIZE];
+
+  // array to store kernel data of previous y
+  VEC4_T prev_kernel_line[TILE_SIZE];
+
   uint16_t kx = uint16_t(0);
-  for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
+  for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) {
     for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
-      // The weight kernel was rearranged such that every NxN filter is
-      // flattened to fit in one row. Each filter was then stacked on top of
-      // each other vertically.
-      const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
-      sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
-      kx++;
+      in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
+    }
+
+    // from 2nd iteration onwards accumulate dot product in 2nd sum
+    // based on kernel line data fetched in previous iteration and input texel from this iteration
+    if (i > uint16_t(0)) {
+      for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) {
+        sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]);
+      }
+    }
+
+    // accumulate dot product in 1st sum only until tile size
+    if (i < uint16_t(TILE_SIZE)) {
+      for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) {
+        prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0);
+        sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]);
+      }
     }
   }
 
-  imageStore(t_out, pos, op(sum, out_min, out_max));
+  for (int i = 0; i < BATCH_SIZE_Y; i++) {
+    if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
+      continue;
+    }
+    imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
+  }
 }

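The core of the shader change is the accumulation loop: it walks TILE_SIZE + BATCH_SIZE_Y - 1 input rows, caches each row's texels in in_texels, and keeps the kernel row loaded in the previous iteration in prev_kernel_line so it can be reused for the second output row without another fetch. Below is a scalar C++ sketch of that pattern for the default TILE_SIZE = 3 and BATCH_SIZE_Y = 2 (plain floats stand in for texels, stride and dilation are omitted, and the function and variable names are illustrative rather than the shader's).

#include <array>
#include <cstdio>

constexpr int TILE_SIZE = 3;
constexpr int BATCH_SIZE_Y = 2;

// Computes BATCH_SIZE_Y vertically adjacent outputs of a TILE_SIZE x TILE_SIZE
// depthwise filter. Each input row is read once into in_row; the kernel row
// read in iteration i is cached in prev_kernel_line and reused in iteration
// i + 1 for the second (lower) output row.
std::array<float, BATCH_SIZE_Y> conv_two_rows(
    const float input[TILE_SIZE + BATCH_SIZE_Y - 1][TILE_SIZE],
    const float kernel[TILE_SIZE][TILE_SIZE],
    float bias) {
  std::array<float, BATCH_SIZE_Y> sum;
  sum.fill(bias);

  float in_row[TILE_SIZE];
  float prev_kernel_line[TILE_SIZE];

  for (int i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; i++) {
    for (int j = 0; j < TILE_SIZE; j++) {
      in_row[j] = input[i][j]; // fetch the current input row once
    }
    // from the 2nd iteration on, pair this input row with the kernel row
    // cached last iteration to accumulate the second output row
    if (i > 0) {
      for (int j = 0; j < TILE_SIZE; j++) {
        sum[1] += in_row[j] * prev_kernel_line[j];
      }
    }
    // the first output row only consumes the first TILE_SIZE kernel rows
    if (i < TILE_SIZE) {
      for (int j = 0; j < TILE_SIZE; j++) {
        prev_kernel_line[j] = kernel[i][j];
        sum[0] += in_row[j] * prev_kernel_line[j];
      }
    }
  }
  return sum;
}

int main() {
  const float input[TILE_SIZE + BATCH_SIZE_Y - 1][TILE_SIZE] = {
      {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}};
  const float kernel[TILE_SIZE][TILE_SIZE] = {
      {1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
  const auto sum = conv_two_rows(input, kernel, 0.0f);
  std::printf("row0 = %.0f, row1 = %.0f\n", sum[0], sum[1]); // row0 = 15, row1 = 13
  return 0;
}

With this layout, each kernel row and each needed input texel is fetched only once per invocation while two output rows are produced, which is the caching and reuse the commit title refers to.
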
backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ conv2d_dw_output_tile:
     NDIM: 3
     DTYPE: float
     TILE_SIZE: 3
+    BATCH_SIZE_Y: 2
   generate_variant_forall:
     DTYPE:
       - VALUE: half

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 6 additions & 0 deletions
@@ -296,6 +296,12 @@ utils::uvec3 create_conv2d_global_wg_size(
         utils::div_up(image_extents[0u], 2u),
         utils::div_up(image_extents[1u], 2u),
         image_extents[2u]};
+  } else if (method == Conv2dMethod::Depthwise) {
+    const utils::uvec3 image_extents = graph.logical_limits_of(out);
+    return {
+        utils::div_up(image_extents[0u], 1u),
+        utils::div_up(image_extents[1u], 2u),
+        image_extents[2u]};
   } else {
     return graph.create_global_wg_size(out);
   }

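On the runtime side, the new Depthwise branch dispatches half as many invocations along y, matching BATCH_SIZE_Y = 2 in the shader; the div_up(..., 1u) on x keeps the expression parallel to the pointwise branch and is effectively a no-op. A quick sketch of the resulting global work-group size (div_up here is a local stand-in for utils::div_up, and the extents are made up for illustration):

#include <cstdint>
#include <cstdio>

// Ceiling division, standing in for utils::div_up.
static uint32_t div_up(uint32_t n, uint32_t d) {
  return (n + d - 1) / d;
}

int main() {
  // Hypothetical logical limits of the output image: width, height, channel texels.
  const uint32_t extents[3] = {16, 9, 4};
  // Depthwise path after this change: full width, half the rows (each
  // invocation covers BATCH_SIZE_Y = 2 rows), full channel-texel depth.
  const uint32_t wg[3] = {
      div_up(extents[0], 1u), div_up(extents[1], 2u), extents[2]};
  std::printf("global wg size = {%u, %u, %u}\n", wg[0], wg[1], wg[2]); // {16, 5, 4}
  return 0;
}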