Skip to content

Commit 8749884

Browse files
committed
[ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue.
Pull Request resolved: #7595 This diff makes the current conv2d dw implementation in the Vulkan backend of ExecuTorch a special case that is used only when stride equals dilation, since that is the only case in which this kind of caching is possible. If stride does not equal dilation, the old implementation is used. Additional test cases are added to ensure the computation is correct when stride != dilation. ghstack-source-id: 260951737 @exported-using-ghexport Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/)
1 parent 25d8f15 commit 8749884

File tree

4 files changed

+123
-9
lines changed

4 files changed

+123
-9
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#define TILE_SIZE ${TILE_SIZE}
1616

17+
#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}
18+
1719
#define BATCH_SIZE_X ${BATCH_SIZE_X}
1820

1921
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
@@ -42,6 +44,8 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4244
* Computes a depthwise convolution. Each shader invocation calculates the
4345
* output at a single output location.
4446
*/
47+
48+
#if STRIDE_EQ_DILATION
4549
void main() {
4650
// x and y are divided by batch size to determine 3d position
4751
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -123,3 +127,42 @@ void main() {
123127
}
124128
}
125129
}
130+
131+
#else
132+
/*
 * Fallback depthwise convolution, compiled when STRIDE_EQ_DILATION is 0. Each
 * shader invocation computes the output texel at a single output location by
 * walking a TILE_SIZE x TILE_SIZE window of the input, stepping by dilation.
 */
void main() {
  // The global work group size is flattened into the x dimension by the
  // dispatching code, so recover the 3D output position from the linear
  // invocation index.
  const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
  const ivec3 pos = ivec3(
      gl_GlobalInvocationID.x % out_limits.x,
      div_by_x % out_limits.y,
      div_by_x / out_limits.y);

  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Negative
  // indices indicate that the top-left element is in a region added by padding.
  const ivec2 ipos = pos.xy * stride - padding;

  // Padding is assumed to be constant 0 padding. No explicit bounds check is
  // performed on the input coordinates below.
  // NOTE(review): this presumably relies on out-of-bounds texelFetch reads
  // returning zero - confirm against the device's robustness guarantees.
  VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
  int kx = 0;
  for (int y = ipos.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
    for (int x = ipos.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
      const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
      // The weight kernel was rearranged such that every NxN filter is
      // flattened to fit in one row. Each filter was then stacked on top of
      // each other vertically, so kx walks along the row for channel pos.z.
      sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
      kx++;
    }
  }

  imageStore(t_out, pos, op(sum, out_min, out_max));
}
167+
168+
#endif

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ conv2d_dw_output_tile:
1212
TILE_SIZE: 3
1313
BATCH_SIZE_X: 4
1414
BATCH_SIZE_Y: 2
15+
STRIDE_EQ_DILATION: 0
1516
generate_variant_forall:
1617
DTYPE:
1718
- VALUE: half
@@ -25,3 +26,15 @@ conv2d_dw_output_tile:
2526
- NAME: conv2d_dw_output_tile_5x5_clamp
2627
OPERATOR: clamp(X, A, B)
2728
TILE_SIZE: 5
29+
- NAME: conv2d_dw_sed_output_tile_3x3
30+
STRIDE_EQ_DILATION: 1
31+
- NAME: conv2d_dw_sed_output_tile_3x3_clamp
32+
OPERATOR: clamp(X, A, B)
33+
STRIDE_EQ_DILATION: 1
34+
- NAME: conv2d_dw_sed_output_tile_5x5
35+
TILE_SIZE: 5
36+
STRIDE_EQ_DILATION: 1
37+
- NAME: conv2d_dw_sed_output_tile_5x5_clamp
38+
OPERATOR: clamp(X, A, B)
39+
TILE_SIZE: 5
40+
STRIDE_EQ_DILATION: 1

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader(
126126
const bool prepack_weights,
127127
const Conv2dMethod method,
128128
const ValueRef weight,
129-
const bool clamp_out = false) {
129+
const bool clamp_out = false,
130+
const bool stride_equals_dilation = false) {
130131
std::string kernel_name;
131132
kernel_name.reserve(kShaderNameReserve);
132133
switch (method) {
133134
case Conv2dMethod::Depthwise:
134135
kernel_name = "conv2d_dw";
135136
if (!prepack_weights) {
137+
if (stride_equals_dilation) {
138+
kernel_name += "_sed";
139+
}
136140
const auto& weight_sizes = graph.get_tref(weight)->sizes;
137141
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
138142
kernel_name += "_output_tile_3x3";
@@ -286,22 +290,37 @@ Conv2dMethod get_conv2d_method(
286290
return Conv2dMethod::SlidingWindow;
287291
}
288292

293+
// Returns the (x, y) divisors applied to the output extents when sizing the
// dispatch of the stride == dilation depthwise conv2d shader.
//
// weight_sizes is indexed at positions 2 and 3 with bounds-checked access, so
// a vector that is too short throws std::out_of_range.
// NOTE(review): indices 2 and 3 are presumably the kernel height and width,
// and {4, 2} presumably mirrors BATCH_SIZE_X / BATCH_SIZE_Y in the shader
// yaml - confirm.
//
// The 3x3, 5x5 and default cases are kept as separate slots even though they
// currently all yield the same divisor.
utils::uvec2 get_conv2d_dw_dispatch_divisor(
    const std::vector<int64_t>& weight_sizes) {
  const int64_t dim2 = weight_sizes.at(2);

  // Index 3 is only inspected when index 2 already matches a known size.
  const bool is_3x3 = dim2 == 3 && weight_sizes.at(3) == 3;
  if (is_3x3) {
    return {4u, 2u};
  }

  const bool is_5x5 = dim2 == 5 && weight_sizes.at(3) == 5;
  if (is_5x5) {
    return {4u, 2u};
  }

  // Default for every other kernel size.
  return {4u, 2u};
}
303+
289304
utils::uvec3 create_conv2d_global_wg_size(
290305
ComputeGraph& graph,
291306
const Conv2dMethod method,
292-
const ValueRef out) {
307+
const ValueRef out,
308+
const ValueRef weight_data,
309+
const bool stride_equals_dilation) {
293310
if (method == Conv2dMethod::Pointwise) {
294311
const utils::uvec3 image_extents = graph.logical_limits_of(out);
295312
return {
296313
utils::div_up(image_extents[0u], 2u),
297314
utils::div_up(image_extents[1u], 2u),
298315
image_extents[2u]};
299-
} else if (method == Conv2dMethod::Depthwise) {
300-
const utils::uvec3 image_extents = graph.logical_limits_of(out);
316+
} else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
317+
const utils::uvec3 image_extents = graph.create_global_wg_size(out);
318+
const utils::uvec2 div =
319+
get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes);
301320
return {
302-
utils::div_up(image_extents[0u], 4u),
303-
utils::div_up(image_extents[1u], 2u),
304-
image_extents[2u]};
321+
utils::div_up(image_extents[0], div[0]),
322+
utils::div_up(image_extents[1], div[1]),
323+
image_extents[2]};
305324
} else {
306325
return graph.create_global_wg_size(out);
307326
}
@@ -364,6 +383,10 @@ void add_conv2d_node(
364383
Conv2dParams extra_params =
365384
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
366385

386+
const bool stride_equals_dilation =
387+
(kernel_params.stride[0] == kernel_params.dilation[0] &&
388+
kernel_params.stride[1] == kernel_params.dilation[1]);
389+
367390
OutputParams out_params = {out_min_val, out_max_val};
368391

369392
check_conv2d_params(kernel_params, transposed_val);
@@ -374,9 +397,11 @@ void add_conv2d_node(
374397
/*prepack_weights = */ false,
375398
method,
376399
weight_data,
377-
clamp_out);
400+
clamp_out,
401+
stride_equals_dilation);
378402

379-
utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
403+
utils::uvec3 wg_size = create_conv2d_global_wg_size(
404+
graph, method, out, weight_data, stride_equals_dilation);
380405

381406
if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
382407
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};

backends/vulkan/test/op_tests/cases.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,39 @@ def get_conv_inputs():
348348
[0, 0],
349349
1,
350350
),
351+
(
352+
(1, 4, 234, 234),
353+
(4, 1, 3, 3),
354+
(4,),
355+
[2, 1],
356+
[1, 1],
357+
[1, 1],
358+
False,
359+
[0, 0],
360+
4,
361+
),
362+
(
363+
(1, 4, 234, 234),
364+
(4, 1, 3, 3),
365+
(4,),
366+
[1, 2],
367+
[1, 1],
368+
[1, 1],
369+
False,
370+
[0, 0],
371+
4,
372+
),
373+
(
374+
(1, 4, 234, 234),
375+
(4, 1, 3, 3),
376+
(4,),
377+
[2, 2],
378+
[1, 1],
379+
[1, 1],
380+
False,
381+
[0, 0],
382+
4,
383+
),
351384
]
352385
)
353386
return test_suite

0 commit comments

Comments
 (0)