Skip to content

Commit d790787

Browse files
authored
[ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue.
Differential Revision: D67908916 Pull Request resolved: #7595
1 parent f027deb commit d790787

File tree

4 files changed

+123
-9
lines changed

4 files changed

+123
-9
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#define TILE_SIZE ${TILE_SIZE}
1616

17+
#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}
18+
1719
#define BATCH_SIZE_X ${BATCH_SIZE_X}
1820

1921
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
@@ -40,6 +42,8 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4042
* Computes a depthwise convolution. Each shader invocation calculates the
4143
* output at a single output location.
4244
*/
45+
46+
#if STRIDE_EQ_DILATION
4347
void main() {
4448
// x and y are divided by batch size to determine 3d position
4549
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -121,3 +125,42 @@ void main() {
121125
}
122126
}
123127
}
128+
129+
#else
130+
/*
 * Depthwise convolution variant compiled when STRIDE_EQ_DILATION is 0.
 * Each invocation produces exactly one output texel: it starts from the bias
 * texel and accumulates TILE_SIZE x TILE_SIZE (input * weight) products.
 */
void main() {
  // Only gl_GlobalInvocationID.x is read here; the linear id is unpacked into
  // an (x, y, z) output position using the output extents.
  const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
  const ivec3 pos = ivec3(
      gl_GlobalInvocationID.x % out_limits.x,
      div_by_x % out_limits.y,
      div_by_x / out_limits.y);

  // Invocations that land outside the output extents have nothing to write.
  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Negative
  // indices indicate that the top-left element is in a region added by padding.
  const ivec2 ipos = pos.xy * stride - padding;

  // First input coordinate to load. The coordinate is NOT clamped to the input
  // extents: the loops below always visit TILE_SIZE x TILE_SIZE taps.
  // NOTE(review): taps that fall in the padding region are fetched like any
  // other; this presumably relies on out-of-range texelFetch yielding zero so
  // that padding behaves as constant 0 - confirm for the target devices.
  const ivec2 start = ipos;

  VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
  // Column into the flattened filter row; advanced once per tap, row-major.
  int kx = 0;
  for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
    for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
      // The weight kernel was rearranged such that every NxN filter is
      // flattened to fit in one row. Each filter was then stacked on top of
      // each other vertically.
      const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
      sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
      kx++;
    }
  }

  // Apply the variant's output operator (e.g. clamp) before storing.
  imageStore(t_out, pos, op(sum, out_min, out_max));
}
165+
166+
#endif

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ conv2d_dw_output_tile:
1212
TILE_SIZE: 3
1313
BATCH_SIZE_X: 4
1414
BATCH_SIZE_Y: 2
15+
STRIDE_EQ_DILATION: 0
1516
generate_variant_forall:
1617
DTYPE:
1718
- VALUE: half
@@ -25,3 +26,15 @@ conv2d_dw_output_tile:
2526
- NAME: conv2d_dw_output_tile_5x5_clamp
2627
OPERATOR: clamp(X, A, B)
2728
TILE_SIZE: 5
29+
- NAME: conv2d_dw_sed_output_tile_3x3
30+
STRIDE_EQ_DILATION: 1
31+
- NAME: conv2d_dw_sed_output_tile_3x3_clamp
32+
OPERATOR: clamp(X, A, B)
33+
STRIDE_EQ_DILATION: 1
34+
- NAME: conv2d_dw_sed_output_tile_5x5
35+
TILE_SIZE: 5
36+
STRIDE_EQ_DILATION: 1
37+
- NAME: conv2d_dw_sed_output_tile_5x5_clamp
38+
OPERATOR: clamp(X, A, B)
39+
TILE_SIZE: 5
40+
STRIDE_EQ_DILATION: 1

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader(
126126
const bool prepack_weights,
127127
const Conv2dMethod method,
128128
const ValueRef weight,
129-
const bool clamp_out = false) {
129+
const bool clamp_out = false,
130+
const bool stride_equals_dilation = false) {
130131
std::string kernel_name;
131132
kernel_name.reserve(kShaderNameReserve);
132133
switch (method) {
133134
case Conv2dMethod::Depthwise:
134135
kernel_name = "conv2d_dw";
135136
if (!prepack_weights) {
137+
if (stride_equals_dilation) {
138+
kernel_name += "_sed";
139+
}
136140
const auto& weight_sizes = graph.get_tref(weight)->sizes;
137141
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
138142
kernel_name += "_output_tile_3x3";
@@ -286,22 +290,37 @@ Conv2dMethod get_conv2d_method(
286290
return Conv2dMethod::SlidingWindow;
287291
}
288292

293+
/**
 * Returns the {x, y} divisors applied to the output extents when sizing the
 * dispatch for the depthwise output-tile shaders.
 *
 * @param weight_sizes sizes of the weight tensor; elements 2 and 3 are
 *     inspected (presumably kernel height and width - TODO confirm layout).
 * @return {4u, 2u} for every kernel size at present. The 3x3 and 5x5 cases
 *     are kept as separate branches so each can be tuned independently.
 *     NOTE(review): these values equal BATCH_SIZE_X / BATCH_SIZE_Y in
 *     conv2d_dw_output_tile.yaml and presumably must stay in sync with them.
 * @throws std::out_of_range (from std::vector::at) if an inspected element
 *     does not exist.
 */
utils::uvec2 get_conv2d_dw_dispatch_divisor(
    const std::vector<int64_t>& weight_sizes) {
  // True when elements 2 and 3 both equal `side`. Element 3 is only read when
  // element 2 already matched.
  const auto is_square_kernel = [&weight_sizes](const int64_t side) {
    return weight_sizes.at(2) == side && weight_sizes.at(3) == side;
  };

  if (is_square_kernel(3)) {
    return {4u, 2u};
  }
  if (is_square_kernel(5)) {
    return {4u, 2u};
  }
  // Any other kernel size uses the same divisors.
  return {4u, 2u};
}
303+
289304
utils::uvec3 create_conv2d_global_wg_size(
290305
ComputeGraph& graph,
291306
const Conv2dMethod method,
292-
const ValueRef out) {
307+
const ValueRef out,
308+
const ValueRef weight_data,
309+
const bool stride_equals_dilation) {
293310
if (method == Conv2dMethod::Pointwise) {
294311
const utils::uvec3 image_extents = graph.logical_limits_of(out);
295312
return {
296313
utils::div_up(image_extents[0u], 2u),
297314
utils::div_up(image_extents[1u], 2u),
298315
image_extents[2u]};
299-
} else if (method == Conv2dMethod::Depthwise) {
300-
const utils::uvec3 image_extents = graph.logical_limits_of(out);
316+
} else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
317+
const utils::uvec3 image_extents = graph.create_global_wg_size(out);
318+
const utils::uvec2 div =
319+
get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes);
301320
return {
302-
utils::div_up(image_extents[0u], 4u),
303-
utils::div_up(image_extents[1u], 2u),
304-
image_extents[2u]};
321+
utils::div_up(image_extents[0], div[0]),
322+
utils::div_up(image_extents[1], div[1]),
323+
image_extents[2]};
305324
} else {
306325
return graph.create_global_wg_size(out);
307326
}
@@ -364,6 +383,10 @@ void add_conv2d_node(
364383
Conv2dParams extra_params =
365384
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
366385

386+
const bool stride_equals_dilation =
387+
(kernel_params.stride[0] == kernel_params.dilation[0] &&
388+
kernel_params.stride[1] == kernel_params.dilation[1]);
389+
367390
OutputParams out_params = {out_min_val, out_max_val};
368391

369392
check_conv2d_params(kernel_params, transposed_val);
@@ -374,9 +397,11 @@ void add_conv2d_node(
374397
/*prepack_weights = */ false,
375398
method,
376399
weight_data,
377-
clamp_out);
400+
clamp_out,
401+
stride_equals_dilation);
378402

379-
utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
403+
utils::uvec3 wg_size = create_conv2d_global_wg_size(
404+
graph, method, out, weight_data, stride_equals_dilation);
380405

381406
if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
382407
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};

backends/vulkan/test/op_tests/cases.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,39 @@ def get_conv_inputs():
348348
[0, 0],
349349
1,
350350
),
351+
(
352+
(1, 4, 234, 234),
353+
(4, 1, 3, 3),
354+
(4,),
355+
[2, 1],
356+
[1, 1],
357+
[1, 1],
358+
False,
359+
[0, 0],
360+
4,
361+
),
362+
(
363+
(1, 4, 234, 234),
364+
(4, 1, 3, 3),
365+
(4,),
366+
[1, 2],
367+
[1, 1],
368+
[1, 1],
369+
False,
370+
[0, 0],
371+
4,
372+
),
373+
(
374+
(1, 4, 234, 234),
375+
(4, 1, 3, 3),
376+
(4,),
377+
[2, 2],
378+
[1, 1],
379+
[1, 1],
380+
False,
381+
[0, 0],
382+
4,
383+
),
351384
]
352385
)
353386
return test_suite

0 commit comments

Comments
 (0)