Skip to content

Commit ee636b9

Browse files
committed
Update on "[ET-VK][Ops] aten.avg_pool2d"
## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full implementation. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/) [ghstack-poisoned]
2 parents f45a9e2 + 6fb7746 commit ee636b9

File tree

2 files changed

+1
-5
lines changed

2 files changed

+1
-5
lines changed

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,10 @@ std::vector<int64_t> calc_out_sizes_hw(
156156
make_ivec2_kernel_size(graph, weight, kernel_size_only);
157157
const auto stride = make_ivec2_from_list(graph, args[0]);
158158
const auto padding = make_ivec2_from_list(graph, args[1]);
159-
160159
const auto dilation = args[2] == kDummyValueRef
161160
? api::utils::ivec2{1, 1}
162161
: make_ivec2_from_list(graph, args[2]);
162+
163163
if (transposed) {
164164
const auto output_padding = make_ivec2_from_list(graph, args[3]);
165165
return calc_transpose_out_sizes_hw(

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ def assert_outputs_equal(
7878
)
7979
else:
8080
# If one output, eager returns tensor while executor tuple of size 1
81-
# print("ET-VK")
82-
# print(model_output[0])
83-
# print("Eager-mode")
84-
# print(ref_output)
8581
self.assertTrue(
8682
torch.allclose(
8783
model_output[0],

0 commit comments

Comments
 (0)