Skip to content

Commit c15689b

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Add portable upsample_nearest2d kernel
Summary: Add a upsample_nearest2d kernel to the portable kernel library. This implementation re-uses some of the inner logic from the ATen implementation (see Upsample.h and UpsampleKernel.cpp), however I have not ported the outer kernel structure as it relies on TensorIterator and runtime allocation. It may be worth re-visiting this in the future, either by looking at pulling in more of the ATen implementation or adding an optimized variant. Differential Revision: D66089829
1 parent 9a884a8 commit c15689b

File tree

11 files changed

+602
-2
lines changed

11 files changed

+602
-2
lines changed

kernels/aten/functions.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,6 @@
407407

408408
- op: upsample_bilinear2d.vec_out
409409

410-
- op: upsample_nearest2d.out
411-
412410
- op: upsample_nearest2d.vec_out
413411

414412
- op: var.correction_out
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/upsample_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
namespace native {
15+
16+
using exec_aten::ArrayRef;
17+
using exec_aten::SizesType;
18+
using exec_aten::optional;
19+
20+
namespace {
21+
template <typename CTYPE>
22+
void upsample_nearest2d_kernel_impl(
23+
const Tensor& in,
24+
const float scale_h,
25+
const float scale_w,
26+
Tensor& out) {
27+
const auto in_data = in.const_data_ptr<CTYPE>();
28+
auto out_data = out.mutable_data_ptr<CTYPE>();
29+
30+
auto in_plane = in_data;
31+
for (auto n = 0; n < out.size(0); n++) {
32+
for (auto c = 0; c < out.size(1); c++) {
33+
for (auto h = 0; h < out.size(2); h++) {
34+
for (auto w = 0; w < out.size(3); w++) {
35+
const auto in_h = nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]);
36+
const auto in_w = nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]);
37+
38+
*out_data = in_plane[in_h * in.strides()[2] + in_w * in.strides()[3]];
39+
out_data++;
40+
}
41+
}
42+
43+
in_plane += in.strides()[1];
44+
}
45+
}
46+
}
47+
}
48+
49+
Tensor& upsample_nearest2d_vec_out(
50+
KernelRuntimeContext& ctx,
51+
const Tensor& in,
52+
const exec_aten::OptionalArrayRef<int64_t> output_size,
53+
const exec_aten::OptionalArrayRef<double> scale_factors,
54+
Tensor& out) {
55+
56+
// Preconditions (checked in check_..._args):
57+
// In and out tensors have same dtype.
58+
// In and out tensors are rank 4 and have same dim[0] and dim[1].
59+
// In and out tensors are default dim order (NCHW).
60+
ET_KERNEL_CHECK(
61+
ctx,
62+
check_upsample_nearest2d_args(in, output_size, scale_factors, out),
63+
InvalidArgument,
64+
out);
65+
66+
double scale_h, scale_w;
67+
68+
ET_KERNEL_CHECK_MSG(
69+
ctx,
70+
resize_upsample_2d(in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
71+
InvalidArgument,
72+
out,
73+
"Failed to resize output tensor");
74+
75+
const auto kernel_scale_h = area_pixel_compute_scale<double>(in.sizes()[2], out.sizes()[2], false, scale_h);
76+
const auto kernel_scale_w = area_pixel_compute_scale<double>(in.sizes()[3], out.sizes()[3], false, scale_w);
77+
78+
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() {
79+
upsample_nearest2d_kernel_impl<CTYPE>(in, kernel_scale_h, kernel_scale_w, out);
80+
});
81+
82+
return out;
83+
}
84+
85+
}
86+
}
87+
}

kernels/portable/cpu/util/upsample_util.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ bool check_upsample_bilinear2d_args(
4646
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
4747
}
4848

49+
bool check_upsample_nearest2d_args(
50+
const Tensor& in,
51+
const exec_aten::OptionalArrayRef<int64_t>& output_size,
52+
const exec_aten::OptionalArrayRef<double>& scale_factors,
53+
Tensor& out) {
54+
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
55+
}
56+
4957
Error resize_upsample_2d(
5058
const Tensor& in,
5159
const exec_aten::OptionalArrayRef<int64_t>& output_size,

kernels/portable/cpu/util/upsample_util.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ bool check_upsample_bilinear2d_args(
2828
const exec_aten::OptionalArrayRef<double>& scale_factors,
2929
Tensor& out);
3030

31+
bool check_upsample_nearest2d_args(
32+
const Tensor& in,
33+
const exec_aten::OptionalArrayRef<int64_t>& output_size,
34+
const exec_aten::OptionalArrayRef<double>& scale_factors,
35+
Tensor& out);
36+
3137
Error resize_upsample_2d(
3238
const Tensor& in,
3339
const exec_aten::OptionalArrayRef<int64_t>& output_size,
@@ -127,5 +133,17 @@ inline void compute_source_index_and_lambda(
127133
}
128134
}
129135

136+
// Ported from aten/src/ATen/native/UpSample.h
137+
inline int64_t nearest_neighbor_compute_source_index(
138+
const float scale,
139+
int64_t dst_index,
140+
int64_t input_size) {
141+
// Index computation matching OpenCV INTER_NEAREST
142+
// which is buggy and kept for BC
143+
const int64_t src_index =
144+
std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
145+
return src_index;
146+
}
147+
130148
} // namespace executor
131149
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,11 @@
922922
- arg_meta: null
923923
kernel_name: torch::executor::upsample_bilinear2d_vec_out
924924

925+
- op: upsample_nearest2d.vec_out
926+
kernels:
927+
- arg_meta: null
928+
kernel_name: torch::executor::upsample_nearest2d_vec_out
929+
925930
- op: var.correction_out
926931
kernels:
927932
- arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ runtime.cxx_library(
2121
deps = [
2222
"//executorch/extension/aten_util:aten_bridge",
2323
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
24+
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",
2425
"//executorch/runtime/core/exec_aten:lib",
2526
],
2627
external_deps = [
@@ -40,3 +41,16 @@ python_unittest(
4041
"//caffe2:torch",
4142
],
4243
)
44+
45+
python_unittest(
46+
name = "op_upsample_nearest2d_test",
47+
srcs = [
48+
"op_upsample_nearest2d_test.py",
49+
],
50+
preload_deps = [
51+
":aot_ops_test_lib",
52+
],
53+
deps = [
54+
"//caffe2:torch",
55+
],
56+
)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import itertools
10+
import unittest
11+
12+
from typing import Optional, Sequence
13+
14+
import torch
15+
16+
17+
class UpsampleNearest2dTest(unittest.TestCase):
18+
def run_upsample_test(
19+
self,
20+
inp: torch.Tensor,
21+
output_size: Optional[Sequence[int]] = None,
22+
scale_factors: Optional[Sequence[float]] = None,
23+
atol=1e-7,
24+
) -> None:
25+
aten_result = torch.nn.functional.interpolate(
26+
inp,
27+
size=output_size,
28+
mode="nearest",
29+
scale_factor=scale_factors,
30+
)
31+
et_result = torch.zeros_like(aten_result)
32+
et_result = torch.ops.et_test.upsample_nearest2d(
33+
inp,
34+
output_size=output_size,
35+
scale_factors=scale_factors,
36+
out=et_result,
37+
)
38+
self.assertTrue(
39+
torch.allclose(et_result, aten_result, atol=atol),
40+
msg=f"ET: {et_result} \n ATen: {aten_result} \n Error: {et_result.to(torch.float) - aten_result.to(torch.float)}",
41+
)
42+
43+
def test_upsample_nearest2d_aten_parity_f32(self):
44+
N = [1, 2]
45+
C = [1, 3]
46+
H = [1, 3, 50, 1001]
47+
W = [1, 2, 62, 1237]
48+
OUT_H = [5, 21]
49+
OUT_W = [7, 31]
50+
51+
for n, c, h, w, out_h, out_w in itertools.product(
52+
N, C, H, W, OUT_H, OUT_W
53+
):
54+
input = torch.randn(n, c, h, w)
55+
self.run_upsample_test(
56+
input, output_size=(out_h, out_w)
57+
)
58+
self.run_upsample_test(
59+
input, scale_factors=(out_h / h, out_w / w)
60+
)
61+
62+
def test_upsample_nearest2d_aten_parity_u8(self):
63+
N = [1, 2]
64+
C = [1, 3]
65+
H = [1, 3, 50, 1001]
66+
W = [1, 2, 62, 1237]
67+
OUT_H = [5, 21]
68+
OUT_W = [7, 31]
69+
70+
for n, c, h, w, out_h, out_w in itertools.product(
71+
N, C, H, W, OUT_H, OUT_W
72+
):
73+
input = torch.randint(0, 255, (n, c, h, w), dtype=torch.uint8)
74+
self.run_upsample_test(
75+
input, output_size=(out_h, out_w), atol=1
76+
)
77+
self.run_upsample_test(
78+
input,
79+
scale_factors=(out_h / h, out_w / w),
80+
atol=2,
81+
)

kernels/portable/test/register_ops_aot_for_test.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,37 @@ Tensor& upsample_bilinear2d_vec_out_no_context(
4747

4848
return ret;
4949
}
50+
51+
Tensor& upsample_nearest2d_vec_out(
52+
KernelRuntimeContext& ctx,
53+
const Tensor& in,
54+
const exec_aten::OptionalArrayRef<int64_t> output_size,
55+
const exec_aten::OptionalArrayRef<double> scale_factors,
56+
Tensor& out);
57+
58+
Tensor& upsample_nearest2d_vec_out_no_context(
59+
const Tensor& in,
60+
const exec_aten::OptionalArrayRef<int64_t> output_size,
61+
const exec_aten::OptionalArrayRef<double> scale_factors,
62+
Tensor& out) {
63+
KernelRuntimeContext ctx;
64+
auto& ret = upsample_nearest2d_vec_out(
65+
ctx, in, output_size, scale_factors, out);
66+
67+
if (ctx.failure_state() != Error::Ok) {
68+
throw std::runtime_error(std::string("Kernel failed with error: ") + std::to_string((int) ctx.failure_state()));
69+
}
70+
71+
return ret;
72+
}
5073
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue,
5174
// facebook-hte-ParameterMightThrowOnCopy)
5275

5376
TORCH_LIBRARY(et_test, m) {
5477
m.def(
5578
"upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)",
5679
WRAP_TO_ATEN(upsample_bilinear2d_vec_out_no_context, 4));
80+
m.def("upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(upsample_nearest2d_vec_out_no_context, 3));
5781
}
5882

5983
} // namespace native

0 commit comments

Comments
 (0)