
Commit a4ffb2a

Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator

`nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature.

```
- func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
```

## Implementation

This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover the `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future.

Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/)

[ghstack-poisoned]
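For context, a minimal sketch of how such an `nn.Module` can be lowered so the pooling call shows up as `aten.avg_pool2d.default` in the Edge Dialect. It assumes the standard `torch.export` / `executorch.exir.to_edge` flow; the module, input shape, and pooling parameters are illustrative and not taken from this diff.

```python
# Illustrative lowering sketch: an AvgPool2d module exported and converted to
# the Edge Dialect, where the op appears as aten.avg_pool2d.default.
import torch
from executorch.exir import to_edge


class PoolModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Parameters chosen to exercise the cases this diff covers.
        self.pool = torch.nn.AvgPool2d(
            kernel_size=3,
            stride=2,
            padding=1,
            ceil_mode=True,
            count_include_pad=False,
            divisor_override=None,
        )

    def forward(self, x):
        return self.pool(x)


ep = torch.export.export(PoolModel(), (torch.randn(1, 8, 32, 32),))
edge = to_edge(ep)
# The printed Edge graph should contain an aten.avg_pool2d.default call.
print(edge.exported_program().graph_module.code)
```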
2 parents: 62a10c8 + cebf8f7

File tree

  • backends/vulkan/runtime/graph/ops/impl

1 file changed (+3, -4 lines)

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 3 additions & 4 deletions
```diff
@@ -124,13 +124,12 @@ void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 //
 
 struct DivisorParams final {
-  int32_t divisor;
+  int32_t divisor_override;
   bool count_include_pad;
 };
 
 DivisorParams create_divisor_params(
     ComputeGraph& graph,
-    const api::utils::ivec2& kernel_size,
     const ValueRef divisor_override,
     const ValueRef count_include_pad) {
   return {
@@ -165,8 +164,8 @@ void add_avg_pool2d_node(
   Kernel2dParams kernel_params =
       create_kernel2d_params(graph, kernel_size, stride, padding);
 
-  DivisorParams divisor_params = create_divisor_params(
-      graph, kernel_params.kernel_size, divisor_override, count_include_pad);
+  DivisorParams divisor_params =
+      create_divisor_params(graph, divisor_override, count_include_pad);
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
```
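To make the "divisor's computation is now a bit complex" remark concrete, here is a rough Python sketch of the per-output-window divisor selection that the `divisor_override` / `count_include_pad` fields of `DivisorParams` have to support. It follows standard `avg_pool2d` semantics and is purely illustrative; it is not the shader code touched by this diff.

```python
# Illustrative avg_pool2d divisor selection for one output element, following
# standard PyTorch semantics. Not the Vulkan shader from this diff.
def window_divisor(oh, ow, stride, padding, kernel, in_h, in_w,
                   count_include_pad, divisor_override):
    # Window start in padded coordinates.
    hstart = oh * stride[0] - padding[0]
    wstart = ow * stride[1] - padding[1]
    # Clamp the window end to the padded input; this matters when
    # ceil_mode=True lets a window run past the padded boundary.
    hend = min(hstart + kernel[0], in_h + padding[0])
    wend = min(wstart + kernel[1], in_w + padding[1])
    padded_count = (hend - hstart) * (wend - wstart)

    # Clamp to the real (unpadded) input to count only valid elements.
    valid_count = ((min(hend, in_h) - max(hstart, 0)) *
                   (min(wend, in_w) - max(wstart, 0)))

    if divisor_override is not None:
        return divisor_override
    return padded_count if count_include_pad else valid_count
```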
