Skip to content

Commit 556fd61

Browse files
committed
Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/) [ghstack-poisoned]
2 parents 7a57ea3 + a7540ec commit 556fd61

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def get_avg_pool2d_inputs():
165165
]
166166

167167
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
168+
test_suite.dtypes = ["at::kFloat"]
168169
return test_suite
169170

170171

0 commit comments

Comments
 (0)