Skip to content

Commit fc44904

Browse files
committed
[ET-VK][Ops] aten.avg_pool2d
Pull Request resolved: #3770 ## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full C-packing implementation including dynamic shape support. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. ghstack-source-id: 228344034 Differential Revision: [D57918523](https://our.internmc.facebook.com/intern/diff/D57918523/)
1 parent dfdea20 commit fc44904

File tree

8 files changed

+300
-19
lines changed

8 files changed

+300
-19
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __contains__(self, op):
7373
]
7474

7575
POOLING_OPS = [
76+
exir_ops.edge.aten.avg_pool2d.default,
7677
exir_ops.edge.aten.max_pool2d_with_indices.default,
7778
]
7879

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
#include "indexing_utils.h"
16+
17+
layout(std430) buffer;
18+
19+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
20+
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
21+
${layout_declare_ubo(2, "ivec3", "out_limits")}
22+
${layout_declare_ubo(3, "ivec4", "in_sizes")}
23+
${layout_declare_ubo(4, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
24+
${layout_declare_ubo(5, "int", "divisor_override", "int", "count_include_pad")}
25+
26+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
27+
28+
void main() {
29+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
30+
31+
if (any(greaterThanEqual(pos, out_limits))) {
32+
return;
33+
}
34+
35+
const ivec2 ipos = pos.xy * stride - padding;
36+
37+
const ivec2 start = max(ivec2(0), ipos);
38+
const ivec2 end = min(ipos + kernel_size, ivec2(in_sizes.xy));
39+
40+
VEC4_T sum = VEC4_T(0);
41+
for (int y = start.y; y < end.y; ++y) {
42+
for (int x = start.x; x < end.x; ++x) {
43+
sum += texelFetch(t_in, ivec3(x, y, pos.z), 0);
44+
}
45+
}
46+
47+
int div;
48+
if (divisor_override > 0) {
49+
div = divisor_override;
50+
} else if (count_include_pad > 0) {
51+
ivec2 empty = max(ipos + kernel_size - padding - ivec2(in_sizes.xy), ivec2(0));
52+
div = (kernel_size.y - empty.y) * (kernel_size.x - empty.x);
53+
} else {
54+
div = (end.y - start.y) * (end.x - start.x);
55+
}
56+
imageStore(t_out, pos, sum / div);
57+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
avg_pool2d:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
NDIM: 3
11+
STORAGE: texture3d
12+
generate_variant_forall:
13+
DTYPE:
14+
- VALUE: half
15+
- VALUE: float
16+
- VALUE: int
17+
shader_variants:
18+
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,18 @@
1717

1818
namespace vkcompute {
1919

20-
void resize_max_pool2d_node(
20+
void check_pool2d_args(const vTensor& in, const vTensor& out) {
21+
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
22+
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
23+
}
24+
25+
void resize_pool2d_node(
2126
ComputeGraph* graph,
2227
const std::vector<ArgGroup>& args,
2328
const std::vector<ValueRef>& extra_args) {
29+
bool is_max_pool2d = extra_args[3] != kDummyValueRef;
30+
2431
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
25-
vTensorPtr indices = graph->get_tensor(args[0].refs[1]);
2632
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
2733

2834
size_t ndim = self->sizes().size();
@@ -45,14 +51,17 @@ void resize_max_pool2d_node(
4551
new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
4652

4753
out->virtual_resize(new_out_sizes);
48-
indices->virtual_resize(new_out_sizes);
49-
}
5054

51-
void check_max_pool2d_args(const vTensor& in, const vTensor& out) {
52-
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
53-
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
55+
if (is_max_pool2d) {
56+
vTensorPtr indices = graph->get_tensor(args[0].refs[1]);
57+
indices->virtual_resize(new_out_sizes);
58+
}
5459
}
5560

61+
//
62+
// max_pool2d
63+
//
64+
5665
void add_max_pool2d_node(
5766
ComputeGraph& graph,
5867
const ValueRef in,
@@ -68,7 +77,7 @@ void add_max_pool2d_node(
6877
const auto out_val = graph.get_value_list(out);
6978
vTensorPtr t_out = graph.get_tensor(out_val->at(0));
7079

71-
check_max_pool2d_args(*t_in, *t_out);
80+
check_pool2d_args(*t_in, *t_out);
7281

7382
api::utils::uvec3 global_size = t_out->image_extents();
7483
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
@@ -101,7 +110,7 @@ void add_max_pool2d_node(
101110
// Specialization Constants
102111
{},
103112
// Resizing Logic
104-
resize_max_pool2d_node,
113+
resize_pool2d_node,
105114
{kernel_size, stride, padding, dilation, ceil_mode}));
106115
}
107116

@@ -110,7 +119,92 @@ void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
110119
graph, args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
111120
}
112121

122+
//
123+
// avg_pool2d
124+
//
125+
126+
struct DivisorParams final {
127+
int32_t divisor_override;
128+
bool count_include_pad;
129+
};
130+
131+
DivisorParams create_divisor_params(
132+
ComputeGraph& graph,
133+
const ValueRef divisor_override,
134+
const ValueRef count_include_pad) {
135+
return {
136+
graph.val_is_int(divisor_override)
137+
? static_cast<int32_t>(graph.get_int(divisor_override))
138+
: 0,
139+
graph.get_bool(count_include_pad)};
140+
}
141+
142+
void add_avg_pool2d_node(
143+
ComputeGraph& graph,
144+
const ValueRef in,
145+
const ValueRef kernel_size,
146+
const ValueRef stride,
147+
const ValueRef padding,
148+
const ValueRef ceil_mode,
149+
const ValueRef count_include_pad,
150+
const ValueRef divisor_override,
151+
const ValueRef out) {
152+
ValueRef arg = prepack_if_tensor_ref(graph, in);
153+
vTensorPtr t_in = graph.get_tensor(arg);
154+
vTensorPtr t_out = graph.get_tensor(out);
155+
156+
check_pool2d_args(*t_in, *t_out);
157+
158+
api::utils::uvec3 global_size = t_out->image_extents();
159+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
160+
161+
std::string kernel_name("avg_pool2d");
162+
add_dtype_suffix(kernel_name, *t_out);
163+
164+
Kernel2dParams kernel_params =
165+
create_kernel2d_params(graph, kernel_size, stride, padding);
166+
167+
DivisorParams divisor_params =
168+
create_divisor_params(graph, divisor_override, count_include_pad);
169+
170+
graph.execute_nodes().emplace_back(new ExecuteNode(
171+
graph,
172+
VK_KERNEL_FROM_STR(kernel_name),
173+
global_size,
174+
local_size,
175+
// Inputs and Outputs
176+
{{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
177+
// Shader params buffers
178+
{t_out->texture_limits_ubo(),
179+
t_in->sizes_ubo(),
180+
graph.create_params_buffer(kernel_params),
181+
graph.create_params_buffer(divisor_params)},
182+
// Specialization Constants
183+
{},
184+
// Resizing Logic
185+
resize_pool2d_node,
186+
{kernel_size,
187+
stride,
188+
padding,
189+
/*dilation= */ kDummyValueRef,
190+
ceil_mode}));
191+
}
192+
193+
void avg_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
194+
return add_avg_pool2d_node(
195+
graph,
196+
args[0],
197+
args[1],
198+
args[2],
199+
args[3],
200+
args[4],
201+
args[5],
202+
args[6],
203+
args[7]);
204+
}
205+
113206
REGISTER_OPERATORS {
207+
VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d);
114208
VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
115209
}
116210

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ Kernel2dParams create_kernel2d_params(
4141
};
4242
}
4343

44+
Kernel2dParams create_kernel2d_params(
45+
ComputeGraph& graph,
46+
const ValueRef kernel_size,
47+
const ValueRef stride,
48+
const ValueRef padding) {
49+
return {
50+
make_ivec2_kernel_size(graph, kernel_size, /*kernel_size_only = */ true),
51+
make_ivec2_from_list(graph, stride),
52+
make_ivec2_from_list(graph, padding),
53+
{},
54+
};
55+
}
56+
4457
int64_t calc_out_size(
4558
const int64_t in_size,
4659
const int64_t kernel_size,
@@ -143,7 +156,9 @@ std::vector<int64_t> calc_out_sizes_hw(
143156
make_ivec2_kernel_size(graph, weight, kernel_size_only);
144157
const auto stride = make_ivec2_from_list(graph, args[0]);
145158
const auto padding = make_ivec2_from_list(graph, args[1]);
146-
const auto dilation = make_ivec2_from_list(graph, args[2]);
159+
const auto dilation = args[2] == kDummyValueRef
160+
? api::utils::ivec2{1, 1}
161+
: make_ivec2_from_list(graph, args[2]);
147162

148163
if (transposed) {
149164
const auto output_padding = make_ivec2_from_list(graph, args[3]);

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,6 @@
1616

1717
namespace vkcompute {
1818

19-
struct Kernel2dParams final {
20-
api::utils::ivec2 kernel_size;
21-
api::utils::ivec2 stride;
22-
api::utils::ivec2 padding;
23-
api::utils::ivec2 dilation;
24-
};
25-
2619
struct Kernel1dParams final {
2720
int kernel_size;
2821
int stride;
@@ -32,6 +25,13 @@ struct Kernel1dParams final {
3225
int out_group_size;
3326
};
3427

28+
struct Kernel2dParams final {
29+
api::utils::ivec2 kernel_size;
30+
api::utils::ivec2 stride;
31+
api::utils::ivec2 padding;
32+
api::utils::ivec2 dilation;
33+
};
34+
3535
Kernel2dParams create_kernel2d_params(
3636
ComputeGraph& graph,
3737
const ValueRef weight,
@@ -40,6 +40,12 @@ Kernel2dParams create_kernel2d_params(
4040
const ValueRef padding,
4141
const ValueRef dilation);
4242

43+
Kernel2dParams create_kernel2d_params(
44+
ComputeGraph& graph,
45+
const ValueRef kernel_size,
46+
const ValueRef stride,
47+
const ValueRef padding);
48+
4349
int64_t calc_out_size(
4450
const int64_t in_size,
4551
const int64_t kernel_size,

backends/vulkan/test/op_tests/cases.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,62 @@ def get_linear_inputs():
114114
return test_suite
115115

116116

117-
def get_pool2d_inputs():
117+
def get_avg_pool2d_inputs():
118+
Test = namedtuple(
119+
"VkAvgPoolTest",
120+
[
121+
"self",
122+
"kernel_size",
123+
"stride",
124+
"padding",
125+
"ceil_mode",
126+
"count_include_pad",
127+
"divisor_override",
128+
],
129+
)
130+
Test.__new__.__defaults__ = (None, None)
131+
132+
test_cases = []
133+
134+
for ceil_mode in [True, False]:
135+
for count_include_pad in [True, False]:
136+
for divisor_override in [None, 5]:
137+
test_cases += [
138+
Test(
139+
self=(S, M1, M2),
140+
kernel_size=[2, 2],
141+
stride=[1, 1],
142+
padding=[0, 0],
143+
ceil_mode=ceil_mode,
144+
count_include_pad=count_include_pad,
145+
divisor_override=divisor_override,
146+
),
147+
Test(
148+
self=(S, M1, M2),
149+
kernel_size=[5, 4],
150+
stride=[3, 1],
151+
padding=[2, 1],
152+
ceil_mode=ceil_mode,
153+
count_include_pad=count_include_pad,
154+
divisor_override=divisor_override,
155+
),
156+
Test(
157+
self=(S, M1, M2),
158+
kernel_size=[4, 5],
159+
stride=[1, 3],
160+
padding=[2, 1],
161+
ceil_mode=ceil_mode,
162+
count_include_pad=count_include_pad,
163+
divisor_override=divisor_override,
164+
),
165+
]
166+
167+
test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
168+
test_suite.dtypes = ["at::kFloat"]
169+
return test_suite
170+
171+
172+
def get_max_pool2d_inputs():
118173
test_suite = VkTestSuite(
119174
[
120175
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
@@ -869,7 +924,8 @@ def get_arange_inputs():
869924
"aten.bmm.default": get_bmm_inputs(),
870925
"aten.mm.default": get_mm_inputs(),
871926
"aten.linear.default": get_linear_inputs(),
872-
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
927+
"aten.avg_pool2d.default": get_avg_pool2d_inputs(),
928+
"aten.max_pool2d_with_indices.default": get_max_pool2d_inputs(),
873929
"aten.convolution.default": get_conv_inputs(),
874930
"aten.native_layer_norm.default": get_native_layer_norm_inputs(),
875931
"aten.full.default": get_full_inputs(),

0 commit comments

Comments
 (0)