Skip to content

Add portable upsample_nearest2d kernel #7464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions kernels/aten/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,6 @@

- op: upsample_bilinear2d.vec_out

- op: upsample_nearest2d.out

- op: upsample_nearest2d.vec_out

- op: var.correction_out
Expand Down
93 changes: 93 additions & 0 deletions kernels/portable/cpu/op_upsample_nearest2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_nearest2d_kernel_impl(
const Tensor& in,
const float scale_h,
const float scale_w,
Tensor& out) {
const auto in_data = in.const_data_ptr<CTYPE>();
auto out_data = out.mutable_data_ptr<CTYPE>();

auto in_plane = in_data;
for (auto n = 0; n < out.size(0); n++) {
for (auto c = 0; c < out.size(1); c++) {
for (auto h = 0; h < out.size(2); h++) {
for (auto w = 0; w < out.size(3); w++) {
const auto in_h =
nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]);
const auto in_w =
nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]);

*out_data = in_plane[in_h * in.strides()[2] + in_w * in.strides()[3]];
out_data++;
}
}

in_plane += in.strides()[1];
}
}
}
} // namespace

Tensor& upsample_nearest2d_vec_out(
KernelRuntimeContext& ctx,
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t> output_size,
const exec_aten::OptionalArrayRef<double> scale_factors,
Tensor& out) {
// Preconditions (checked in check_..._args):
// In and out tensors have same dtype.
// In and out tensors are rank 4 and have same dim[0] and dim[1].
// In and out tensors are default dim order (NCHW).
ET_KERNEL_CHECK(
ctx,
check_upsample_nearest2d_args(in, output_size, scale_factors, out),
InvalidArgument,
out);

double scale_h, scale_w;

ET_KERNEL_CHECK_MSG(
ctx,
resize_upsample_2d(
in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor");

const auto kernel_scale_h = area_pixel_compute_scale<double>(
in.sizes()[2], out.sizes()[2], false, scale_h);
const auto kernel_scale_w = area_pixel_compute_scale<double>(
in.sizes()[3], out.sizes()[3], false, scale_w);

ET_SWITCH_REAL_TYPES(
in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() {
upsample_nearest2d_kernel_impl<CTYPE>(
in, kernel_scale_h, kernel_scale_w, out);
});

return out;
}

} // namespace native
} // namespace executor
} // namespace torch
8 changes: 8 additions & 0 deletions kernels/portable/cpu/util/upsample_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ bool check_upsample_bilinear2d_args(
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
}

bool check_upsample_nearest2d_args(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out) {
return check_upsample_2d_common_args(in, output_size, scale_factors, out);
}

Error resize_upsample_2d(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
Expand Down
18 changes: 18 additions & 0 deletions kernels/portable/cpu/util/upsample_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ bool check_upsample_bilinear2d_args(
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out);

bool check_upsample_nearest2d_args(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
const exec_aten::OptionalArrayRef<double>& scale_factors,
Tensor& out);

Error resize_upsample_2d(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t>& output_size,
Expand Down Expand Up @@ -127,5 +133,17 @@ inline void compute_source_index_and_lambda(
}
}

// Ported from aten/src/ATen/native/UpSample.h
inline int64_t nearest_neighbor_compute_source_index(
const float scale,
int64_t dst_index,
int64_t input_size) {
// Index computation matching OpenCV INTER_NEAREST
// which is buggy and kept for BC
const int64_t src_index =
std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
return src_index;
}

} // namespace executor
} // namespace torch
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,11 @@
- arg_meta: null
kernel_name: torch::executor::upsample_bilinear2d_vec_out

- op: upsample_nearest2d.vec_out
kernels:
- arg_meta: null
kernel_name: torch::executor::upsample_nearest2d_vec_out

- op: var.correction_out
kernels:
- arg_meta: null
Expand Down
14 changes: 14 additions & 0 deletions kernels/portable/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ runtime.cxx_library(
deps = [
"//executorch/extension/aten_util:aten_bridge",
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",
"//executorch/runtime/core/exec_aten:lib",
],
external_deps = [
Expand All @@ -40,3 +41,16 @@ python_unittest(
"//caffe2:torch",
],
)

python_unittest(
name = "op_upsample_nearest2d_test",
srcs = [
"op_upsample_nearest2d_test.py",
],
preload_deps = [
":aot_ops_test_lib",
],
deps = [
"//caffe2:torch",
],
)
71 changes: 71 additions & 0 deletions kernels/portable/test/op_upsample_nearest2d_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import unittest

from typing import Optional, Sequence

import torch


class UpsampleNearest2dTest(unittest.TestCase):
def run_upsample_test(
self,
inp: torch.Tensor,
output_size: Optional[Sequence[int]] = None,
scale_factors: Optional[Sequence[float]] = None,
atol=1e-7,
) -> None:
aten_result = torch.nn.functional.interpolate(
inp,
size=output_size,
mode="nearest",
scale_factor=scale_factors,
)
et_result = torch.zeros_like(aten_result)
et_result = torch.ops.et_test.upsample_nearest2d(
inp,
output_size=output_size,
scale_factors=scale_factors,
out=et_result,
)
self.assertTrue(
torch.allclose(et_result, aten_result, atol=atol),
msg=f"ET: {et_result} \n ATen: {aten_result} \n Error: {et_result.to(torch.float) - aten_result.to(torch.float)}",
)

def test_upsample_nearest2d_aten_parity_f32(self):
N = [1, 2]
C = [1, 3]
H = [1, 3, 50, 1001]
W = [1, 2, 62, 1237]
OUT_H = [5, 21]
OUT_W = [7, 31]

for n, c, h, w, out_h, out_w in itertools.product(N, C, H, W, OUT_H, OUT_W):
input = torch.randn(n, c, h, w)
self.run_upsample_test(input, output_size=(out_h, out_w))
self.run_upsample_test(input, scale_factors=(out_h / h, out_w / w))

def test_upsample_nearest2d_aten_parity_u8(self):
N = [1, 2]
C = [1, 3]
H = [1, 3, 50, 1001]
W = [1, 2, 62, 1237]
OUT_H = [5, 21]
OUT_W = [7, 31]

for n, c, h, w, out_h, out_w in itertools.product(N, C, H, W, OUT_H, OUT_W):
input = torch.randint(0, 255, (n, c, h, w), dtype=torch.uint8)
self.run_upsample_test(input, output_size=(out_h, out_w), atol=1)
self.run_upsample_test(
input,
scale_factors=(out_h / h, out_w / w),
atol=2,
)
28 changes: 28 additions & 0 deletions kernels/portable/test/register_ops_aot_for_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,41 @@ Tensor& upsample_bilinear2d_vec_out_no_context(

return ret;
}

Tensor& upsample_nearest2d_vec_out(
KernelRuntimeContext& ctx,
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t> output_size,
const exec_aten::OptionalArrayRef<double> scale_factors,
Tensor& out);

Tensor& upsample_nearest2d_vec_out_no_context(
const Tensor& in,
const exec_aten::OptionalArrayRef<int64_t> output_size,
const exec_aten::OptionalArrayRef<double> scale_factors,
Tensor& out) {
KernelRuntimeContext ctx;
auto& ret =
upsample_nearest2d_vec_out(ctx, in, output_size, scale_factors, out);

if (ctx.failure_state() != Error::Ok) {
throw std::runtime_error(
std::string("Kernel failed with error: ") +
std::to_string((int)ctx.failure_state()));
}

return ret;
}
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue,
// facebook-hte-ParameterMightThrowOnCopy)

TORCH_LIBRARY(et_test, m) {
m.def(
"upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)",
WRAP_TO_ATEN(upsample_bilinear2d_vec_out_no_context, 4));
m.def(
"upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)",
WRAP_TO_ATEN(upsample_nearest2d_vec_out_no_context, 3));
}

} // namespace native
Expand Down
Loading