Skip to content

Commit acea876

Browse files
committed
Add portable upsample_nearest2d kernel
Summary: Add an upsample_nearest2d kernel to the portable kernel library. This implementation reuses some of the inner logic from the ATen implementation (see UpSample.h and UpSampleKernel.cpp); however, I have not ported the outer kernel structure, as it relies on TensorIterator and runtime allocation. It may be worth revisiting this in the future, either by pulling in more of the ATen implementation or by adding an optimized variant. Differential Revision: D66089829
1 parent 9a884a8 commit acea876

File tree

11 files changed

+652
-2
lines changed

11 files changed

+652
-2
lines changed

kernels/aten/functions.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,6 @@
407407

408408
- op: upsample_bilinear2d.vec_out
409409

410-
- op: upsample_nearest2d.out
411-
412410
- op: upsample_nearest2d.vec_out
413411

414412
- op: var.correction_out
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/upsample_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
namespace native {
15+
16+
using exec_aten::ArrayRef;
17+
using exec_aten::optional;
18+
using exec_aten::SizesType;
19+
20+
namespace {
21+
template <typename CTYPE>
22+
void upsample_nearest2d_kernel_impl(
23+
const Tensor& in,
24+
const float scale_h,
25+
const float scale_w,
26+
Tensor& out) {
27+
const auto in_data = in.const_data_ptr<CTYPE>();
28+
auto out_data = out.mutable_data_ptr<CTYPE>();
29+
30+
auto in_plane = in_data;
31+
for (auto n = 0; n < out.size(0); n++) {
32+
for (auto c = 0; c < out.size(1); c++) {
33+
for (auto h = 0; h < out.size(2); h++) {
34+
for (auto w = 0; w < out.size(3); w++) {
35+
const auto in_h =
36+
nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]);
37+
const auto in_w =
38+
nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]);
39+
40+
*out_data = in_plane[in_h * in.strides()[2] + in_w * in.strides()[3]];
41+
out_data++;
42+
}
43+
}
44+
45+
in_plane += in.strides()[1];
46+
}
47+
}
48+
}
49+
} // namespace
50+
51+
// Out-variant kernel for upsample_nearest2d.vec_out.
//
// Validates the arguments, resizes `out`, derives the per-axis sampling scales
// and dispatches on the input dtype to upsample_nearest2d_kernel_impl().
//
// @param ctx            Kernel runtime context; check failures are reported
//                       through it by the ET_KERNEL_CHECK* macros.
// @param in             Input tensor.
// @param output_size    Optional explicit output size.
// @param scale_factors  Optional scale factors.
// @param out            Output tensor; resized by resize_upsample_2d() and
//                       then filled in.
// @return               `out`.
Tensor& upsample_nearest2d_vec_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t> output_size,
    const exec_aten::OptionalArrayRef<double> scale_factors,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  //  In and out tensors have same dtype.
  //  In and out tensors are rank 4 and have same dim[0] and dim[1].
  //  In and out tensors are default dim order (NCHW).
  ET_KERNEL_CHECK(
      ctx,
      check_upsample_nearest2d_args(in, output_size, scale_factors, out),
      InvalidArgument,
      out);

  // Expected to be populated by resize_upsample_2d() below; initialised so
  // they never hold an indeterminate value.
  double scale_h = 0.0;
  double scale_w = 0.0;

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_upsample_2d(
          in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor");

  // The third argument (false) is the align_corners flag position of
  // area_pixel_compute_scale — NOTE(review): confirm against its declaration.
  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], false, scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], false, scale_w);

  // The dispatch name matches the operator this kernel is registered for
  // (upsample_nearest2d.vec_out). The kernel takes its scales as float, so
  // the values computed above are narrowed at the call.
  ET_SWITCH_REAL_TYPES(
      in.scalar_type(), ctx, "upsample_nearest2d.vec_out", CTYPE, [&]() {
        upsample_nearest2d_kernel_impl<CTYPE>(
            in, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}
90+
91+
} // namespace native
92+
} // namespace executor
93+
} // namespace torch

kernels/portable/cpu/util/upsample_util.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ bool check_upsample_bilinear2d_args(
4646
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
4747
}
4848

49+
bool check_upsample_nearest2d_args(
50+
const Tensor& in,
51+
const exec_aten::OptionalArrayRef<int64_t>& output_size,
52+
const exec_aten::OptionalArrayRef<double>& scale_factors,
53+
Tensor& out) {
54+
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
55+
}
56+
4957
Error resize_upsample_2d(
5058
const Tensor& in,
5159
const exec_aten::OptionalArrayRef<int64_t>& output_size,

kernels/portable/cpu/util/upsample_util.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ bool check_upsample_bilinear2d_args(
2828
const exec_aten::OptionalArrayRef<double>& scale_factors,
2929
Tensor& out);
3030

31+
bool check_upsample_nearest2d_args(
32+
const Tensor& in,
33+
const exec_aten::OptionalArrayRef<int64_t>& output_size,
34+
const exec_aten::OptionalArrayRef<double>& scale_factors,
35+
Tensor& out);
36+
3137
Error resize_upsample_2d(
3238
const Tensor& in,
3339
const exec_aten::OptionalArrayRef<int64_t>& output_size,
@@ -127,5 +133,17 @@ inline void compute_source_index_and_lambda(
127133
}
128134
}
129135

136+
// Ported from aten/src/ATen/native/UpSample.h
//
// Maps an output index to the input index it samples from along one axis:
// floor(dst_index * scale), clamped from above to the last valid input index
// (input_size - 1).
inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // Index computation matching OpenCV INTER_NEAREST,
  // which is buggy and kept for BC.
  const float scaled_position = dst_index * scale;
  const int64_t floored_index = static_cast<int64_t>(floorf(scaled_position));
  const int64_t last_valid_index = input_size - 1;
  return floored_index < last_valid_index ? floored_index : last_valid_index;
}
147+
130148
} // namespace executor
131149
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,11 @@
922922
- arg_meta: null
923923
kernel_name: torch::executor::upsample_bilinear2d_vec_out
924924

925+
- op: upsample_nearest2d.vec_out
926+
kernels:
927+
- arg_meta: null
928+
kernel_name: torch::executor::upsample_nearest2d_vec_out
929+
925930
- op: var.correction_out
926931
kernels:
927932
- arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ runtime.cxx_library(
2121
deps = [
2222
"//executorch/extension/aten_util:aten_bridge",
2323
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
24+
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",
2425
"//executorch/runtime/core/exec_aten:lib",
2526
],
2627
external_deps = [
@@ -40,3 +41,16 @@ python_unittest(
4041
"//caffe2:torch",
4142
],
4243
)
44+
45+
python_unittest(
46+
name = "op_upsample_nearest2d_test",
47+
srcs = [
48+
"op_upsample_nearest2d_test.py",
49+
],
50+
preload_deps = [
51+
":aot_ops_test_lib",
52+
],
53+
deps = [
54+
"//caffe2:torch",
55+
],
56+
)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import itertools
10+
import unittest
11+
12+
from typing import Optional, Sequence
13+
14+
import torch
15+
16+
17+
class UpsampleNearest2dTest(unittest.TestCase):
    """Parity tests for the ExecuTorch upsample_nearest2d test op.

    Each case computes a reference with
    ``torch.nn.functional.interpolate(mode="nearest")`` and compares it with
    the result of ``torch.ops.et_test.upsample_nearest2d``.
    """

    # Shape grid shared by all parity tests.
    BATCH_SIZES = (1, 2)
    CHANNEL_COUNTS = (1, 3)
    IN_HEIGHTS = (1, 3, 50, 1001)
    IN_WIDTHS = (1, 2, 62, 1237)
    OUT_HEIGHTS = (5, 21)
    OUT_WIDTHS = (7, 31)

    def run_upsample_test(
        self,
        inp: torch.Tensor,
        output_size: Optional[Sequence[int]] = None,
        scale_factors: Optional[Sequence[float]] = None,
        atol=1e-7,
    ) -> None:
        """Run one comparison between the ATen reference and the ET op.

        Args:
            inp: Input tensor passed to both implementations.
            output_size: Target (H, W), or None when using scale_factors.
            scale_factors: Per-axis scale factors, or None when using
                output_size.
            atol: Absolute tolerance passed to ``torch.allclose``.
        """
        aten_result = torch.nn.functional.interpolate(
            inp,
            size=output_size,
            mode="nearest",
            scale_factor=scale_factors,
        )
        # Pre-allocate the out tensor with the reference's shape and dtype.
        out_buffer = torch.zeros_like(aten_result)
        et_result = torch.ops.et_test.upsample_nearest2d(
            inp,
            output_size=output_size,
            scale_factors=scale_factors,
            out=out_buffer,
        )
        self.assertTrue(
            torch.allclose(et_result, aten_result, atol=atol),
            msg=f"ET: {et_result} \n ATen: {aten_result} \n Error: {et_result.to(torch.float) - aten_result.to(torch.float)}",
        )

    def _run_parity_grid(self, make_input, size_atol, scale_atol) -> None:
        """Sweep the shared shape grid, testing both ways of sizing the output.

        Args:
            make_input: Callable ``(n, c, h, w) -> torch.Tensor`` producing
                the input for one grid point.
            size_atol: Tolerance for the explicit ``output_size`` case.
            scale_atol: Tolerance for the ``scale_factors`` case.
        """
        shape_grid = itertools.product(
            self.BATCH_SIZES,
            self.CHANNEL_COUNTS,
            self.IN_HEIGHTS,
            self.IN_WIDTHS,
            self.OUT_HEIGHTS,
            self.OUT_WIDTHS,
        )
        for n, c, h, w, out_h, out_w in shape_grid:
            inp = make_input(n, c, h, w)
            self.run_upsample_test(inp, output_size=(out_h, out_w), atol=size_atol)
            self.run_upsample_test(
                inp,
                scale_factors=(out_h / h, out_w / w),
                atol=scale_atol,
            )

    def test_upsample_nearest2d_aten_parity_f32(self):
        self._run_parity_grid(
            lambda n, c, h, w: torch.randn(n, c, h, w),
            size_atol=1e-7,
            scale_atol=1e-7,
        )

    def test_upsample_nearest2d_aten_parity_u8(self):
        # torch.randint's upper bound is exclusive, so 256 is needed to cover
        # the full uint8 range including 255.
        self._run_parity_grid(
            lambda n, c, h, w: torch.randint(
                0, 256, (n, c, h, w), dtype=torch.uint8
            ),
            size_atol=1,
            scale_atol=2,
        )

kernels/portable/test/register_ops_aot_for_test.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,41 @@ Tensor& upsample_bilinear2d_vec_out_no_context(
4747

4848
return ret;
4949
}
50+
51+
Tensor& upsample_nearest2d_vec_out(
52+
KernelRuntimeContext& ctx,
53+
const Tensor& in,
54+
const exec_aten::OptionalArrayRef<int64_t> output_size,
55+
const exec_aten::OptionalArrayRef<double> scale_factors,
56+
Tensor& out);
57+
58+
// Context-free adapter around upsample_nearest2d_vec_out() for the AOT test
// registration below.
//
// Creates a local KernelRuntimeContext, forwards all arguments to the kernel,
// and turns a failure recorded on that context into a C++ exception.
//
// @return  The tensor reference returned by the kernel.
// @throws  std::runtime_error if the context's failure state is not
//          Error::Ok after the kernel returns; the message carries the
//          numeric error code.
Tensor& upsample_nearest2d_vec_out_no_context(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t> output_size,
    const exec_aten::OptionalArrayRef<double> scale_factors,
    Tensor& out) {
  KernelRuntimeContext ctx;
  auto& ret =
      upsample_nearest2d_vec_out(ctx, in, output_size, scale_factors, out);

  // The kernel reports errors through ctx rather than its return value.
  if (ctx.failure_state() != Error::Ok) {
    throw std::runtime_error(
        std::string("Kernel failed with error: ") +
        std::to_string(static_cast<int>(ctx.failure_state())));
  }

  return ret;
}
5075
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue,
5176
// facebook-hte-ParameterMightThrowOnCopy)
5277

5378
// Registers the test-only operators in the `et_test` library. Each schema is
// bound to its context-free wrapper through WRAP_TO_ATEN; the numeric argument
// is 4 for the bilinear op and 3 for the nearest op.
// NOTE(review): that number looks like the zero-based position of `out` in
// each schema — confirm against the WRAP_TO_ATEN definition.
TORCH_LIBRARY(et_test, m) {
  const char* const bilinear_schema =
      "upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)";
  const char* const nearest_schema =
      "upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)";

  m.def(
      bilinear_schema,
      WRAP_TO_ATEN(upsample_bilinear2d_vec_out_no_context, 4));
  m.def(
      nearest_schema,
      WRAP_TO_ATEN(upsample_nearest2d_vec_out_no_context, 3));
}
5886

5987
} // namespace native

0 commit comments

Comments
 (0)