Skip to content

Commit a7cde32

Browse files
authored
[ET-VK][ez] Support convolutions with padding > 0 and dilation > 1 (#10148)
## Context As title. ## Changes Modified convolution shader to set starting input position correctly when padding and dilation are both greater than 0. Removed safeguard check for padding = 0 when dilation > 1 in C++ implementation. Differential Revision: [D72879342](https://our.internmc.facebook.com/intern/diff/D72879342/)
1 parent 85485b8 commit a7cde32

File tree

3 files changed

+193
-151
lines changed

3 files changed

+193
-151
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,17 @@ void main() {
4747

4848
// Compute the start and end of the input indices to load. Padding is assumed
4949
// to be constant 0 padding, so reads from the padding region are skipped.
50-
const ivec2 start = max(ivec2(0), ipos);
50+
ivec2 start = ipos;
51+
if (start.x < 0) {
52+
// number of "steps" to get to >= zero is div_up(-start, dilation)
53+
int num_steps = ((-ipos.x) + dilation.x - 1) / dilation.x;
54+
start.x = ipos.x + num_steps * dilation.x;
55+
}
56+
if (start.y < 0) {
57+
// number of "steps" to get to >= zero is div_up(-start, dilation)
58+
int num_steps = ((-ipos.y) + dilation.y - 1) / dilation.y;
59+
start.y = ipos.y + num_steps * dilation.y;
60+
}
5161
const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy));
5262
// Compute the start of the kernel based on how far we are skipping ahead when
5363
// reading the input. Note that these are "canonical" indices.

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,6 @@ void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
262262
"aten.convolution.default: transposed = true, dilation > 1 is not supported yet!");
263263
}
264264
}
265-
if ((p.padding[0] > 0 && p.kernel_size[0] > 1 && p.dilation[0] > 1) ||
266-
(p.padding[1] > 0 && p.kernel_size[1] > 1 && p.dilation[1] > 1)) {
267-
VK_THROW(
268-
"aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!");
269-
}
270265
}
271266

272267
Conv2dMethod get_conv2d_method(

backends/vulkan/test/op_tests/cases.py

Lines changed: 182 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -226,153 +226,190 @@ def get_max_pool2d_inputs():
226226

227227
@register_test_suite("aten.convolution.default")
228228
def get_conv_inputs():
229-
test_suite = VkTestSuite(
229+
Test = namedtuple(
230+
"ConvTest",
230231
[
231-
(
232-
(1, 6, 40, 50),
233-
(8, 6, 3, 3),
234-
(8,),
235-
[1, 2],
236-
[2, 3],
237-
[1, 1],
238-
False,
239-
[0, 0],
240-
1,
241-
),
242-
(
243-
(1, 6, 40, 50),
244-
(6, 8, 3, 3),
245-
(8,),
246-
[1, 2],
247-
[2, 3],
248-
[1, 1],
249-
True,
250-
[0, 1],
251-
1,
252-
),
253-
(
254-
(1, 8, 72, 96),
255-
(8, 1, 3, 3),
256-
(8,),
257-
[1, 1],
258-
[1, 1],
259-
[1, 1],
260-
False,
261-
[0, 0],
262-
8,
263-
),
264-
(
265-
(1, 8, 72, 96),
266-
(8, 8, 1, 1),
267-
(8,),
268-
[1, 1],
269-
[1, 1],
270-
[1, 1],
271-
False,
272-
[0, 0],
273-
1,
274-
),
275-
(
276-
(1, 6, 40, 50),
277-
(8, 6, 3, 3),
278-
None,
279-
[1, 2],
280-
[2, 3],
281-
[1, 1],
282-
False,
283-
[0, 0],
284-
1,
285-
),
286-
(
287-
(1, 6, 7),
288-
(6, 1, 3),
289-
(6,),
290-
[1],
291-
[0],
292-
[1],
293-
False,
294-
[0],
295-
6,
296-
),
297-
(
298-
(2, 20, 30),
299-
(10, 4, 6),
300-
(10,),
301-
[5],
302-
[5],
303-
[3],
304-
False,
305-
[0],
306-
5,
307-
),
308-
(
309-
(1, 9, 11),
310-
(9, 1, 3),
311-
None,
312-
[1],
313-
[0],
314-
[1],
315-
False,
316-
[0],
317-
9,
318-
),
319-
(
320-
(5, 15, 30),
321-
(20, 3, 3),
322-
None,
323-
[3],
324-
[5],
325-
[7],
326-
False,
327-
[0],
328-
5,
329-
),
330-
(
331-
(1, 16, 672, 512),
332-
(64, 16, 1, 1),
333-
(64,),
334-
[1, 1],
335-
[0, 0],
336-
[1, 1],
337-
False,
338-
[0, 0],
339-
1,
340-
),
341-
(
342-
(1, 4, 234, 234),
343-
(4, 1, 3, 3),
344-
(4,),
345-
[2, 1],
346-
[1, 1],
347-
[1, 1],
348-
False,
349-
[0, 0],
350-
4,
351-
),
352-
(
353-
(1, 4, 234, 234),
354-
(4, 1, 3, 3),
355-
(4,),
356-
[1, 2],
357-
[1, 1],
358-
[1, 1],
359-
False,
360-
[0, 0],
361-
4,
362-
),
363-
(
364-
(1, 4, 234, 234),
365-
(4, 1, 3, 3),
366-
(4,),
367-
[2, 2],
368-
[1, 1],
369-
[1, 1],
370-
False,
371-
[0, 0],
372-
4,
373-
),
374-
]
232+
"self",
233+
"weight",
234+
"bias",
235+
"stride",
236+
"padding",
237+
"dilation",
238+
"transposed",
239+
"output_padding",
240+
"groups",
241+
],
242+
)
243+
Test.__new__.__defaults__ = (
244+
None,
245+
None,
246+
None,
247+
[1, 1],
248+
[0, 0],
249+
[1, 1],
250+
False,
251+
[9, 0],
252+
1,
375253
)
254+
test_cases = []
255+
test_cases = [
256+
Test(
257+
self=(1, 6, 40, 50),
258+
weight=(8, 6, 3, 3),
259+
bias=(8,),
260+
stride=[1, 2],
261+
padding=[2, 3],
262+
dilation=[1, 1],
263+
transposed=False,
264+
output_padding=[0, 0],
265+
groups=1,
266+
),
267+
Test(
268+
self=(1, 6, 40, 50),
269+
weight=(6, 8, 3, 3),
270+
bias=(8,),
271+
stride=[1, 2],
272+
padding=[2, 3],
273+
dilation=[1, 1],
274+
transposed=True,
275+
output_padding=[0, 1],
276+
groups=1,
277+
),
278+
Test(
279+
self=(1, 8, 72, 96),
280+
weight=(8, 1, 3, 3),
281+
bias=(8,),
282+
stride=[1, 1],
283+
padding=[1, 1],
284+
dilation=[1, 1],
285+
transposed=False,
286+
output_padding=[0, 0],
287+
groups=8,
288+
),
289+
Test(
290+
self=(1, 8, 72, 96),
291+
weight=(8, 8, 1, 1),
292+
bias=(8,),
293+
stride=[1, 1],
294+
padding=[1, 1],
295+
dilation=[1, 1],
296+
transposed=False,
297+
output_padding=[0, 0],
298+
groups=1,
299+
),
300+
Test(
301+
self=(1, 6, 40, 50),
302+
weight=(8, 6, 3, 3),
303+
bias=None,
304+
stride=[1, 2],
305+
padding=[2, 3],
306+
dilation=[1, 1],
307+
transposed=False,
308+
output_padding=[0, 0],
309+
groups=1,
310+
),
311+
Test(
312+
self=(1, 6, 7),
313+
weight=(6, 1, 3),
314+
bias=(6,),
315+
stride=[1],
316+
padding=[0],
317+
dilation=[1],
318+
transposed=False,
319+
output_padding=[0],
320+
groups=6,
321+
),
322+
Test(
323+
self=(2, 20, 30),
324+
weight=(10, 4, 6),
325+
bias=(10,),
326+
stride=[5],
327+
padding=[5],
328+
dilation=[3],
329+
transposed=False,
330+
output_padding=[0],
331+
groups=5,
332+
),
333+
Test(
334+
self=(1, 9, 11),
335+
weight=(9, 1, 3),
336+
bias=None,
337+
stride=[1],
338+
padding=[0],
339+
dilation=[1],
340+
transposed=False,
341+
output_padding=[0],
342+
groups=9,
343+
),
344+
Test(
345+
self=(5, 15, 30),
346+
weight=(20, 3, 3),
347+
bias=None,
348+
stride=[3],
349+
padding=[5],
350+
dilation=[7],
351+
transposed=False,
352+
output_padding=[0],
353+
groups=5,
354+
),
355+
Test(
356+
self=(1, 16, 672, 512),
357+
weight=(64, 16, 1, 1),
358+
bias=(64,),
359+
stride=[1, 1],
360+
padding=[0, 0],
361+
dilation=[1, 1],
362+
transposed=False,
363+
output_padding=[0, 0],
364+
groups=1,
365+
),
366+
Test(
367+
self=(1, 4, 234, 234),
368+
weight=(4, 1, 3, 3),
369+
bias=(4,),
370+
stride=[2, 1],
371+
padding=[1, 1],
372+
dilation=[1, 1],
373+
transposed=False,
374+
output_padding=[0, 0],
375+
groups=4,
376+
),
377+
Test(
378+
self=(1, 4, 234, 234),
379+
weight=(4, 1, 3, 3),
380+
bias=(4,),
381+
stride=[1, 2],
382+
padding=[1, 1],
383+
dilation=[1, 1],
384+
transposed=False,
385+
output_padding=[0, 0],
386+
groups=4,
387+
),
388+
Test(
389+
self=(1, 4, 234, 234),
390+
weight=(4, 1, 3, 3),
391+
bias=(4,),
392+
stride=[2, 2],
393+
padding=[1, 1],
394+
dilation=[1, 1],
395+
transposed=False,
396+
output_padding=[0, 0],
397+
groups=4,
398+
),
399+
Test(
400+
self=(1, 8, 90, 77),
401+
weight=(1, 8, 3, 3),
402+
bias=(1,),
403+
stride=[1, 1],
404+
padding=[2, 2],
405+
dilation=[2, 2],
406+
transposed=False,
407+
output_padding=[0, 0],
408+
groups=1,
409+
),
410+
]
411+
412+
test_suite = VkTestSuite(test_cases)
376413
return test_suite
377414

378415

0 commit comments

Comments
 (0)