-
Notifications
You must be signed in to change notification settings - Fork 607
Commit 90b2a01
committed
Update base for Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator
`nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature.
```
- func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
```
## Implementation
This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future.
Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/)
[ghstack-poisoned]1 parent cebf8f7 commit 90b2a01Copy full SHA for 90b2a01
File tree
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changedFilter options
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changed
0 commit comments