
Commit 50f8971

SS-JIA authored and facebook-github-bot committed
Implement avg_pool2d
Reviewed By: guangy10, kirklandsign

Differential Revision: D48483282

fbshipit-source-id: 5197acbc525bc2fd530ac92d542c2e15816b21d5
1 parent 91bf785 commit 50f8971

File tree: 8 files changed, +1367 −53 lines changed

kernels/portable/cpu/op_avg_pool2d.cpp

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using IntArrayRef = exec_aten::ArrayRef<int64_t>;

Tensor& avg_pool2d_out(
    RuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    exec_aten::optional<int64_t> divisor_override,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_avg_pool2d_args(
          in,
          kernel_size,
          stride,
          padding,
          ceil_mode,
          count_include_pad,
          divisor_override,
          out),
      InvalidArgument,
      out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_avg_pool2d_out_target_size(
      in, kernel_size, stride, padding, ceil_mode, output_sizes, &output_ndim);

  ET_KERNEL_CHECK(
      ctx,
      output_size_is_valid({output_sizes, output_ndim}),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  ScalarType in_type = in.scalar_type();
  ET_SWITCH_FLOAT_TYPES_AND(Long, in_type, ctx, __func__, CTYPE, [&]() {
    if (divisor_override.has_value()) {
      int64_t divisor = divisor_override.value();
      // If divisor_override is specified, then we don't need to use `count` in
      // the calculation. Simply sum x / divisor to get the output.
      apply_kernel_2d_reduce_then_map_fn<CTYPE>(
          [](const CTYPE in_val,
             int64_t in_idx,
             CTYPE accum,
             int64_t accum_idx) {
            // Average pooling does not track indexes, so return 0 for accum_idx
            return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
          },
          [divisor](const int64_t count, const CTYPE accum) {
            return accum / static_cast<CTYPE>(divisor);
          },
          count_include_pad,
          in,
          kernel_size,
          stride,
          padding,
          {},
          out);
    } else {
      apply_kernel_2d_reduce_then_map_fn<CTYPE>(
          [](const CTYPE in_val,
             int64_t in_idx,
             CTYPE accum,
             int64_t accum_idx) {
            // Average pooling does not track indexes, so return 0 for accum_idx
            return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
          },
          [](const int64_t count, const CTYPE accum) {
            return accum / static_cast<CTYPE>(count);
          },
          count_include_pad,
          in,
          kernel_size,
          stride,
          padding,
          {},
          out);
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
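
The kernel above folds each 2D window into a running sum via apply_kernel_2d_reduce_then_map_fn, then divides either by the number of contributing elements (`count`) or, when `divisor_override` is set, by that fixed divisor. The following is a minimal, self-contained sketch of that math in plain C++. It is illustrative only: the function name naive_avg_pool2d is not part of the ExecuTorch API, it handles a single channel, and it assumes ceil_mode is false.

// Sketch only: single channel, row-major H x W input, floor output sizing.
#include <cstdint>
#include <optional>
#include <vector>

std::vector<float> naive_avg_pool2d(
    const std::vector<float>& in,
    int64_t H, int64_t W,
    int64_t kH, int64_t kW,
    int64_t sH, int64_t sW,
    int64_t padH, int64_t padW,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  const int64_t outH = (H + 2 * padH - kH) / sH + 1;
  const int64_t outW = (W + 2 * padW - kW) / sW + 1;
  std::vector<float> out(outH * outW, 0.0f);
  for (int64_t oy = 0; oy < outH; ++oy) {
    for (int64_t ox = 0; ox < outW; ++ox) {
      float accum = 0.0f;
      int64_t count = 0;
      for (int64_t ky = 0; ky < kH; ++ky) {
        for (int64_t kx = 0; kx < kW; ++kx) {
          const int64_t iy = oy * sH - padH + ky;
          const int64_t ix = ox * sW - padW + kx;
          const bool inside = iy >= 0 && iy < H && ix >= 0 && ix < W;
          if (inside) {
            accum += in[iy * W + ix]; // padded positions never add to the sum
          }
          if (inside || count_include_pad) {
            ++count; // count_include_pad only affects the divisor
          }
        }
      }
      // divisor_override, when present, replaces count entirely.
      const float div = divisor_override.has_value()
          ? static_cast<float>(divisor_override.value())
          : static_cast<float>(count);
      out[oy * outW + ox] = accum / div;
    }
  }
  return out;
}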

kernels/portable/cpu/op_max_pool2d_with_indices.cpp

Lines changed: 9 additions & 4 deletions
@@ -9,7 +9,6 @@
 #include <cstring>
 
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
-#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -55,7 +54,7 @@ std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
       ctx,
       output_size_is_valid({output_sizes, output_ndim}),
       InvalidArgument,
-      out);
+      ret_val);
 
   ET_KERNEL_CHECK(
       ctx,
@@ -71,13 +70,19 @@ std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
 
   ScalarType in_type = in.scalar_type();
   ET_SWITCH_REAL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
-    apply_kernel_2d_reduce_fn<CTYPE>(
-        [](const CTYPE in_val, int64_t in_idx, CTYPE accum, int64_t accum_idx) {
+    apply_kernel_2d_reduce_then_map_fn<CTYPE>(
+        [](const CTYPE in_val,
+           const int64_t in_idx,
+           const CTYPE accum,
+           const int64_t accum_idx) {
           if (in_val > accum) {
             return std::tuple<CTYPE, int64_t>(in_val, in_idx);
           }
           return std::tuple<CTYPE, int64_t>(accum, accum_idx);
         },
+        // Max pooling does not need to post-process the accumulated output
+        [](const int64_t count, const CTYPE accum) { return accum; },
+        /*include_pad=*/false,
         in,
         kernel_size,
         stride,
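
This hunk moves max pooling onto the same reduce-then-map helper used by avg_pool2d: a first lambda folds each window element (and optionally its index) into an accumulator, and a second lambda post-processes the (count, accum) pair. A rough sketch of that two-phase shape is below; reduce_then_map_window and its signature are assumptions for illustration, not the actual apply_kernel_2d_reduce_then_map_fn helper.

// Illustrative sketch of the reduce-then-map pattern shared by the pooling
// kernels; not the real ExecuTorch helper.
#include <cstdint>
#include <tuple>
#include <vector>

template <typename CTYPE, typename ReduceFn, typename MapFn>
CTYPE reduce_then_map_window(
    const std::vector<CTYPE>& window_vals, // values inside one kernel window
    ReduceFn reduce_fn, // folds (val, idx, accum, accum_idx) -> (accum, accum_idx)
    MapFn map_fn) {     // post-processes (count, accum) -> output value
  if (window_vals.empty()) {
    return CTYPE(0);
  }
  CTYPE accum = window_vals[0];
  int64_t accum_idx = 0;
  for (int64_t i = 1; i < static_cast<int64_t>(window_vals.size()); ++i) {
    std::tie(accum, accum_idx) = reduce_fn(window_vals[i], i, accum, accum_idx);
  }
  return map_fn(static_cast<int64_t>(window_vals.size()), accum);
}

// Max pooling: the reduce step keeps the larger value, the map step is the
// identity. Average pooling: the reduce step sums, the map step divides by
// count (or by divisor_override).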

kernels/portable/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions
@@ -119,6 +119,12 @@ _ATEN_OPS = (
             "//executorch/kernels/portable/cpu/pattern:pattern",
         ],
     ),
+    op_target(
+        name = "op_avg_pool2d",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
+        ],
+    ),
     op_target(
         name = "op_bitwise_and",
         deps = [

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 52 additions & 0 deletions
@@ -192,6 +192,58 @@ void calculate_kernel_output_sizes(
   }
 }
 
+bool check_avg_pool2d_args(
+    const Tensor& in,
+    const IntArrayRef kernel_size,
+    const IntArrayRef stride,
+    const IntArrayRef padding,
+    const bool ceil_mode,
+    const bool count_include_pad,
+    const exec_aten::optional<int64_t>& divisor_override,
+    const Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
+
+  ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
+  if (stride.size() > 0) {
+    ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(kernel_size, 2));
+  }
+  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));
+
+  if (divisor_override.has_value()) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        divisor_override.value() > 0,
+        "divisor_override must be > 0, but found %" PRId64,
+        divisor_override.value());
+  }
+
+  return true;
+}
+
+void get_avg_pool2d_out_target_size(
+    const Tensor& in,
+    const IntArrayRef kernel_size,
+    const IntArrayRef stride,
+    const IntArrayRef padding,
+    const bool ceil_mode,
+    exec_aten::SizesType* const out_sizes,
+    size_t* const out_ndim) {
+  *out_ndim = in.dim();
+
+  // Batch dim is optional, so in can be either 3 or 4 dim.
+  if (in.dim() == 4) {
+    out_sizes[0] = in.size(0);
+    out_sizes[1] = in.size(1);
+  } else {
+    out_sizes[0] = in.size(0);
+  }
+
+  calculate_kernel_output_sizes(
+      in, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
+}
+
 bool check_convolution_args(
     const Tensor& in,
     const Tensor& weight,
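
get_avg_pool2d_out_target_size copies the batch/channel dims through and leaves the spatial dims to calculate_kernel_output_sizes. For reference, each spatial output dim follows the standard 2D pooling size formula; the helper below is a sketch of that formula under the usual PyTorch semantics, not the calculate_kernel_output_sizes implementation.

// Standard pooling output-size formula per spatial dimension (sketch).
#include <cmath>
#include <cstdint>

int64_t pool_output_size(
    int64_t in_size,
    int64_t kernel,
    int64_t stride,
    int64_t pad,
    bool ceil_mode) {
  const double raw =
      static_cast<double>(in_size + 2 * pad - kernel) / stride + 1.0;
  int64_t out = static_cast<int64_t>(ceil_mode ? std::ceil(raw) : std::floor(raw));
  // With ceil_mode, the last window must still start inside the padded input;
  // shrink by one if it would start past the end.
  if (ceil_mode && (out - 1) * stride >= in_size + pad) {
    --out;
  }
  return out;
}

// Example: in_size=7, kernel=3, stride=2, pad=0, ceil_mode=false
//   -> floor((7 - 3) / 2) + 1 = 3.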
