Skip to content

Commit aa90bff

Browse files
committed
Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/) [ghstack-poisoned]
2 parents 556fd61 + 0b63fff commit aa90bff

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

backends/vulkan/test/op_tests/cases.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -144,28 +144,28 @@ def get_avg_pool2d_inputs():
144144
count_include_pad=count_include_pad,
145145
divisor_override=divisor_override,
146146
),
147-
Test(
148-
self=(S, M1, M2),
149-
kernel_size=[5, 4],
150-
stride=[3, 1],
151-
padding=[2, 1],
152-
ceil_mode=ceil_mode,
153-
count_include_pad=count_include_pad,
154-
divisor_override=divisor_override,
155-
),
156-
Test(
157-
self=(S, M1, M2),
158-
kernel_size=[4, 5],
159-
stride=[1, 3],
160-
padding=[2, 1],
161-
ceil_mode=ceil_mode,
162-
count_include_pad=count_include_pad,
163-
divisor_override=divisor_override,
164-
),
165147
]
166-
148+
test_cases += [
149+
Test(
150+
self=(S, M1, M2),
151+
kernel_size=[5, 4],
152+
stride=[3, 1],
153+
padding=[2, 1],
154+
ceil_mode=ceil_mode,
155+
count_include_pad=True,
156+
divisor_override=None,
157+
),
158+
Test(
159+
self=(S, M1, M2),
160+
kernel_size=[4, 5],
161+
stride=[1, 3],
162+
padding=[2, 1],
163+
ceil_mode=ceil_mode,
164+
count_include_pad=True,
165+
divisor_override=None,
166+
),
167+
]
167168
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
168-
test_suite.dtypes = ["at::kFloat"]
169169
return test_suite
170170

171171

0 commit comments

Comments
 (0)