Skip to content

Commit 62a10c8

Browse files
committed
Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/) [ghstack-poisoned]
2 parents ee636b9 + 62ea4c8 commit 62a10c8

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

backends/vulkan/test/op_tests/cases.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,54 @@ def get_linear_inputs():
115115

116116

117117
def get_avg_pool2d_inputs():
118-
Test = namedtuple("VkAvgPoolTest", ["self", "kernel_size", "stride", "padding", "ceil_mode", "count_include_pad", "divisor_override"])
118+
Test = namedtuple(
119+
"VkAvgPoolTest",
120+
[
121+
"self",
122+
"kernel_size",
123+
"stride",
124+
"padding",
125+
"ceil_mode",
126+
"count_include_pad",
127+
"divisor_override",
128+
],
129+
)
119130
Test.__new__.__defaults__ = (None, None)
120131

121-
test_cases = [
122-
Test(self=(S, M1, M2), kernel_size=[2, 2], stride=[1, 1], padding=[0, 0], ceil_mode=False, count_include_pad=True, divisor_override=None),
123-
Test(self=(S, M1, M2), kernel_size=[2, 2], stride=[1, 1], padding=[0, 0], ceil_mode=False, count_include_pad=True, divisor_override=5),
124-
Test(self=(S, M1, M2), kernel_size=[5, 4], stride=[1, 1], padding=[2, 1], ceil_mode=False, count_include_pad=True, divisor_override=None),
125-
Test(self=(S, M1, M2), kernel_size=[4, 5], stride=[1, 1], padding=[2, 1], ceil_mode=False, count_include_pad=True, divisor_override=None),
126-
Test(self=(S, M1, M2), kernel_size=[5, 4], stride=[1, 1], padding=[2, 1], ceil_mode=True, count_include_pad=False, divisor_override=None),
127-
Test(self=(S, M1, M2), kernel_size=[4, 5], stride=[2, 2], padding=[2, 1], ceil_mode=True, count_include_pad=False, divisor_override=None),
128-
Test(self=(S, M1, M2), kernel_size=[5, 4], stride=[3, 1], padding=[2, 1], ceil_mode=False, count_include_pad=False, divisor_override=None),
129-
Test(self=(S, M1, M2), kernel_size=[4, 5], stride=[1, 3], padding=[2, 1], ceil_mode=False, count_include_pad=False, divisor_override=None),
130-
Test(self=(S, M1, M2), kernel_size=[5, 4], stride=[4, 4], padding=[2, 1], ceil_mode=True, count_include_pad=True, divisor_override=None),
131-
Test(self=(S, M1, M2), kernel_size=[4, 5], stride=[1, 1], padding=[2, 1], ceil_mode=True, count_include_pad=True, divisor_override=None),
132-
]
132+
test_cases = []
133+
134+
for ceil_mode in [True, False]:
135+
for count_include_pad in [True, False]:
136+
for divisor_override in [None, 5]:
137+
test_cases += [
138+
Test(
139+
self=(S, M1, M2),
140+
kernel_size=[2, 2],
141+
stride=[1, 1],
142+
padding=[0, 0],
143+
ceil_mode=ceil_mode,
144+
count_include_pad=count_include_pad,
145+
divisor_override=divisor_override,
146+
),
147+
Test(
148+
self=(S, M1, M2),
149+
kernel_size=[5, 4],
150+
stride=[3, 1],
151+
padding=[2, 1],
152+
ceil_mode=ceil_mode,
153+
count_include_pad=count_include_pad,
154+
divisor_override=divisor_override,
155+
),
156+
Test(
157+
self=(S, M1, M2),
158+
kernel_size=[4, 5],
159+
stride=[1, 3],
160+
padding=[2, 1],
161+
ceil_mode=ceil_mode,
162+
count_include_pad=count_include_pad,
163+
divisor_override=divisor_override,
164+
),
165+
]
133166

134167
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
135168
return test_suite

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import re
8-
from typing import Any, List, Tuple
8+
from typing import Any, List
99

1010
from torchgen.api import cpp
1111
from torchgen.api.types import CppSignatureGroup

0 commit comments

Comments
 (0)