Skip to content

Commit 7a57ea3

Browse files
committed
Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/) [ghstack-poisoned]
2 parents a4ffb2a + 90b2a01 commit 7a57ea3

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#version 450 core
1010

1111
#define PRECISION ${PRECISION}
12-
#define FLT_MIN -3.402823466e+38
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
1314

1415
#include "indexing_utils.h"
1516

@@ -36,7 +37,7 @@ void main() {
3637
const ivec2 start = max(ivec2(0), ipos);
3738
const ivec2 end = min(ipos + kernel_size, ivec2(in_sizes.xy));
3839

39-
vec4 sum = vec4(0);
40+
VEC4_T sum = VEC4_T(0);
4041
for (int y = start.y; y < end.y; ++y) {
4142
for (int x = start.x; x < end.x; ++x) {
4243
sum += texelFetch(t_in, ivec3(x, y, pos.z), 0);

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16+
- VALUE: int
1617
shader_variants:
1718
- NAME: avg_pool2d

0 commit comments

Comments
 (0)