Skip to content

Commit d12d4c2

Browse files
committed
[ET-VK] Fix incorrect conv2d dw output when stride != dilation.
This diff turns the current conv2d dw implementation in the Vulkan backend of ExecuTorch into a special case that is used only when stride equals dilation, since that is the only case in which this kind of caching is possible. If stride does not equal dilation, the old implementation is used. Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/) ghstack-source-id: 260756038 Pull Request resolved: #7571
1 parent fc2653b commit d12d4c2

File tree

4 files changed

+116
-12
lines changed

4 files changed

+116
-12
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#define TILE_SIZE ${TILE_SIZE}
1616

17+
#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}
18+
1719
#define BATCH_SIZE_X ${BATCH_SIZE_X}
1820

1921
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
@@ -36,12 +38,12 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3638

3739
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3840

39-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
40-
4141
/*
4242
* Computes a depthwise convolution. Each shader invocation calculates the
4343
* output at a single output location.
4444
*/
45+
46+
#if STRIDE_EQ_DILATION
4547
void main() {
4648
// x and y are divided by batch size to determine 3d position
4749
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -119,3 +121,38 @@ void main() {
119121
}
120122
}
121123
}
124+
125+
#else
126+
/*
 * Depthwise convolution variant compiled when STRIDE_EQ_DILATION is 0. Each
 * shader invocation computes the output texel at a single output location by
 * walking a TILE_SIZE x TILE_SIZE window of the input, stepping by `dilation`.
 */
void main() {
  const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);

  // Invocations that map outside the output extents have nothing to write.
  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Negative
  // indices indicate that the top-left element is in a region added by padding.
  const ivec2 ipos = pos.xy * stride - padding;

  // NOTE(review): there is no explicit bounds check in the loops below, so
  // coordinates that fall in the padding region are passed to texelFetch
  // as-is. This presumably relies on out-of-range fetches yielding zero
  // (constant 0 padding) -- confirm for the target drivers.

  // Start the accumulator from the bias texel of this channel slice.
  VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
  int kx = 0;
  for (int y = ipos.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
    for (int x = ipos.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
      // The weight kernel was rearranged such that every NxN filter is
      // flattened to fit in one row. Each filter was then stacked on top of
      // each other vertically. kx therefore advances once per tap, in the
      // same row-major order in which the taps are visited here.
      const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
      sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
      kx++;
    }
  }

  imageStore(t_out, pos, op(sum, out_min, out_max));
}
157+
158+
#endif

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ conv2d_dw_output_tile:
1212
TILE_SIZE: 3
1313
BATCH_SIZE_X: 4
1414
BATCH_SIZE_Y: 2
15+
STRIDE_EQ_DILATION: 0
1516
generate_variant_forall:
1617
DTYPE:
1718
- VALUE: half
@@ -25,3 +26,15 @@ conv2d_dw_output_tile:
2526
- NAME: conv2d_dw_output_tile_5x5_clamp
2627
OPERATOR: clamp(X, A, B)
2728
TILE_SIZE: 5
29+
- NAME: conv2d_dw_sed_output_tile_3x3
30+
STRIDE_EQ_DILATION: 1
31+
- NAME: conv2d_dw_sed_output_tile_3x3_clamp
32+
OPERATOR: clamp(X, A, B)
33+
STRIDE_EQ_DILATION: 1
34+
- NAME: conv2d_dw_sed_output_tile_5x5
35+
TILE_SIZE: 5
36+
STRIDE_EQ_DILATION: 1
37+
- NAME: conv2d_dw_sed_output_tile_5x5_clamp
38+
OPERATOR: clamp(X, A, B)
39+
TILE_SIZE: 5
40+
STRIDE_EQ_DILATION: 1

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader(
126126
const bool prepack_weights,
127127
const Conv2dMethod method,
128128
const ValueRef weight,
129-
const bool clamp_out = false) {
129+
const bool clamp_out = false,
130+
const bool stride_equals_dilation = false) {
130131
std::string kernel_name;
131132
kernel_name.reserve(kShaderNameReserve);
132133
switch (method) {
133134
case Conv2dMethod::Depthwise:
134135
kernel_name = "conv2d_dw";
135136
if (!prepack_weights) {
137+
if (stride_equals_dilation) {
138+
kernel_name += "_sed";
139+
}
136140
const auto& weight_sizes = graph.get_tref(weight)->sizes;
137141
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
138142
kernel_name += "_output_tile_3x3";
@@ -286,22 +290,33 @@ Conv2dMethod get_conv2d_method(
286290
return Conv2dMethod::SlidingWindow;
287291
}
288292

293+
/**
 * Returns the {x, y} factors by which the output extents are divided when
 * sizing the dispatch for the depthwise convolution shader.
 *
 * The kernel height and width are read from `weight_sizes` at indices 2 and 3
 * (bounds-checked via `at`). 3x3 and 5x5 kernels are distinguished from every
 * other shape, but all three cases currently yield the same divisor {4, 2}.
 *
 * NOTE(review): {4, 2} matches BATCH_SIZE_X / BATCH_SIZE_Y in
 * conv2d_dw_output_tile.yaml; presumably the two must be kept in sync.
 */
utils::uvec2 get_conv2d_dw_dispatch_divisor(
    const std::vector<int64_t>& weight_sizes) {
  // The width is only inspected once the height is one of the known sizes.
  switch (weight_sizes.at(2)) {
    case 3:
      if (weight_sizes.at(3) == 3) {
        // 3x3 kernel.
        return {4u, 2u};
      }
      break;
    case 5:
      if (weight_sizes.at(3) == 5) {
        // 5x5 kernel.
        return {4u, 2u};
      }
      break;
    default:
      break;
  }

  // Any other kernel shape.
  return {4u, 2u};
}
302+
289303
utils::uvec3 create_conv2d_global_wg_size(
290304
ComputeGraph& graph,
291305
const Conv2dMethod method,
292-
const ValueRef out) {
306+
const ValueRef out,
307+
const ValueRef weight_data,
308+
const bool stride_equals_dilation) {
293309
if (method == Conv2dMethod::Pointwise) {
294310
const utils::uvec3 image_extents = graph.logical_limits_of(out);
295311
return {
296312
utils::div_up(image_extents[0u], 2u),
297313
utils::div_up(image_extents[1u], 2u),
298314
image_extents[2u]};
299-
} else if (method == Conv2dMethod::Depthwise) {
300-
const utils::uvec3 image_extents = graph.logical_limits_of(out);
301-
return {
302-
utils::div_up(image_extents[0u], 4u),
303-
utils::div_up(image_extents[1u], 2u),
304-
image_extents[2u]};
315+
} else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
316+
const utils::uvec3 image_extents = graph.create_global_wg_size(out);
317+
const utils::uvec2 div =
318+
get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes);
319+
return {utils::div_up(image_extents[0], div[0]), utils::div_up(image_extents[1], div[1]), image_extents[2]};
305320
} else {
306321
return graph.create_global_wg_size(out);
307322
}
@@ -364,6 +379,10 @@ void add_conv2d_node(
364379
Conv2dParams extra_params =
365380
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
366381

382+
const bool stride_equals_dilation =
383+
(kernel_params.stride[0] == kernel_params.dilation[0] &&
384+
kernel_params.stride[1] == kernel_params.dilation[1]);
385+
367386
OutputParams out_params = {out_min_val, out_max_val};
368387

369388
check_conv2d_params(kernel_params, transposed_val);
@@ -374,9 +393,11 @@ void add_conv2d_node(
374393
/*prepack_weights = */ false,
375394
method,
376395
weight_data,
377-
clamp_out);
396+
clamp_out,
397+
stride_equals_dilation);
378398

379-
utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
399+
utils::uvec3 wg_size = create_conv2d_global_wg_size(
400+
graph, method, out, weight_data, stride_equals_dilation);
380401

381402
if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
382403
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};

backends/vulkan/test/op_tests/cases.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,39 @@ def get_conv_inputs():
348348
[0, 0],
349349
1,
350350
),
351+
(
352+
(1, 4, 234, 234),
353+
(4, 1, 3, 3),
354+
(4,),
355+
[2, 1],
356+
[1, 1],
357+
[1, 1],
358+
False,
359+
[0, 0],
360+
4,
361+
),
362+
(
363+
(1, 4, 234, 234),
364+
(4, 1, 3, 3),
365+
(4,),
366+
[1, 2],
367+
[1, 1],
368+
[1, 1],
369+
False,
370+
[0, 0],
371+
4,
372+
),
373+
(
374+
(1, 4, 234, 234),
375+
(4, 1, 3, 3),
376+
(4,),
377+
[2, 2],
378+
[1, 1],
379+
[1, 1],
380+
False,
381+
[0, 0],
382+
4,
383+
),
351384
]
352385
)
353386
return test_suite

0 commit comments

Comments
 (0)