Arm backend: Add upsample_bilinear2d op #10349

Merged · 1 commit · Apr 22, 2025
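For context, a minimal sketch of the kind of module this PR enables the Arm backend to lower: exporting bilinear upsampling with `align_corners=True` produces the `aten.upsample_bilinear2d.vec` op that the new node visitor legalizes to a TOSA RESIZE. The module below is illustrative only (the `BilinearUpsample` name is hypothetical and not part of this change); the export call uses standard `torch.export` APIs.

```python
import torch


class BilinearUpsample(torch.nn.Module):
    """Hypothetical example module: doubles H and W with bilinear interpolation."""

    def __init__(self):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=2.0, mode="bilinear", align_corners=True
        )

    def forward(self, x):
        return self.upsample(x)


# Exporting the module should yield a graph containing
# torch.ops.aten.upsample_bilinear2d.vec, the op this PR teaches the Arm
# backend to legalize to a TOSA RESIZE operator.
example_input = (torch.rand(1, 3, 8, 8),)
exported = torch.export.export(BilinearUpsample(), example_input)
print(exported.graph_module.code)
```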
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -207,6 +207,7 @@ def is_node_supported(
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.var.correction,
exir_ops.edge.aten.var.dim,
@@ -365,6 +366,7 @@ def is_node_supported(
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.gelu.default,
):
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -46,6 +46,7 @@
op_to_copy,
op_to_dim_order_copy,
op_transpose,
op_upsample_bilinear2d,
op_upsample_nearest2d,
op_view,
op_where,
100 changes: 100 additions & 0 deletions backends/arm/operators/op_upsample_bilinear2d.py
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore


@register_node_visitor
class UpsampleBilinear2dVisitor_0_80(NodeVisitor):
target = "aten.upsample_bilinear2d.vec"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
assert (
inputs[0].shape is not None and output.shape is not None
), "Only static shapes are supported"

input_dtype = inputs[0].dtype

# tosa_shape output is NHWC, take HW
input_size_yx = torch.tensor(
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
)
# Ignore scale and size parameters, directly use the output size as
# we only support static shapes currently
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])

scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
)

def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)

attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
offset=offset_yx.tolist(),
border=border_yx.tolist(),
mode=ResizeMode.BILINEAR,
)

if input_dtype == output.dtype == ts.DType.FP32:
tosa_graph.addOperator(
ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
)
return
elif input_dtype == output.dtype == ts.DType.INT8:
intermediate = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)

tosa_graph.addOperator(
ts.TosaOp.Op().RESIZE, [inputs[0].name], [intermediate.name], attr
)

final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))

build_rescale(
tosa_fb=tosa_graph,
scale=[final_output_scale],
input_node=intermediate,
output_name=output.name,
output_type=ts.DType.INT8,
output_shape=output.shape,
input_zp=0,
output_zp=0,
is_double_round=False,
)
else:
raise ValueError(
"Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}"
)
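For reviewers less familiar with TOSA RESIZE, here is a rough sketch of where the `align_corners=True` scale ratios used above come from, and why the INT8 path rescales the INT32 intermediate by `1 / (scale_n_yx[0] * scale_n_yx[1])`. The helper below is hypothetical and illustrative only; the real parameters come from `get_resize_parameters`. It assumes TOSA RESIZE maps an output index `oy` to the input coordinate `oy * scale_d / scale_n`.

```python
from math import gcd


def align_corners_resize_ratios(in_size: int, out_size: int):
    """Illustrative sketch only; the real values come from get_resize_parameters.

    With align_corners=True the input coordinate for output index oy is
    oy * (in_size - 1) / (out_size - 1), which corresponds to
    scale_n = out_size - 1 and scale_d = in_size - 1 (reduced by their gcd),
    with offset and border both 0.
    """
    scale_n = max(out_size - 1, 1)
    scale_d = max(in_size - 1, 1)
    g = gcd(scale_n, scale_d)
    return scale_n // g, scale_d // g, 0, 0


# Example: upsampling a dimension from 8 to 16 gives scale_n=15, scale_d=7.
# In the INT8 branch above, the bilinear accumulation in TOSA carries a gain of
# scale_y_n * scale_x_n, which is why the visitor rescales the INT32
# intermediate by 1 / (scale_n_yx[0] * scale_n_yx[1]) before writing INT8 output.
print(align_corners_resize_ratios(8, 16))  # (15, 7, 0, 0)
```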
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -215,6 +215,7 @@ def _match_pattern(
torch.ops.aten.flip.default,
torch.ops.aten.chunk.default,
torch.ops.aten.contiguous.default,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.pad.default,
torch.ops.aten.amax.default,
247 changes: 247 additions & 0 deletions backends/arm/test/ops/test_upsample_bilinear2d.py
@@ -0,0 +1,247 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

aten_op = "torch.ops.aten.upsample_bilinear2d.vec"
input_t1 = Tuple[torch.Tensor] # Input x

test_data_suite_tosa = {
# test_name: (test_data, size, scale_factor, compare_outputs)
"rand_double_scale": (torch.rand(2, 4, 8, 3), None, 2.0, True),
"rand_double_scale_one_dim": (torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True),
"rand_double_size": (torch.rand(2, 4, 8, 3), (16, 6), None, True),
"rand_one_double_scale": (torch.rand(2, 4, 1, 1), None, 2.0, True),
"rand_one_double_size": (torch.rand(2, 4, 1, 1), (2, 2), None, True),
"rand_one_same_scale": (torch.rand(2, 4, 1, 1), None, 1.0, True),
"rand_one_same_size": (torch.rand(2, 4, 1, 1), (1, 1), None, True),
# Can't compare outputs, as the rounding in the interpolation differs between
# PyTorch and TOSA for these cases. Just check that the legalization went well.
# TODO: Improve the test infrastructure to support more in-depth verification
# of the TOSA legalization results.
"rand_half_scale": (torch.rand(2, 4, 8, 6), None, 0.5, False),
"rand_half_size": (torch.rand(2, 4, 8, 6), (4, 3), None, False),
"rand_one_and_half_scale": (torch.rand(2, 4, 8, 3), None, 1.5, False),
"rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False),
# Use randn for a number of tests to draw samples from the normal
# distribution, where negative values are also possible.
"randn_double_scale_negative": (torch.randn(2, 4, 8, 3), None, 2.0, True),
"randn_double_scale_one_dim_negative": (
torch.randn(2, 4, 8, 3),
None,
(1.0, 2.0),
True,
),
"randn_double_size_negative": (torch.randn(2, 4, 8, 3), (16, 6), None, True),
"randn_one_double_scale_negative": (torch.randn(2, 4, 1, 1), None, 2.0, True),
"randn_one_double_size_negative": (torch.randn(2, 4, 1, 1), (2, 2), None, True),
"randn_one_same_scale_negative": (torch.randn(2, 4, 1, 1), None, 1.0, True),
"randn_one_same_size_negative": (torch.randn(2, 4, 1, 1), (1, 1), None, True),
}

test_data_suite_Uxx = {
"rand_half_scale": (torch.rand(2, 4, 8, 6), None, 0.5, False),
"rand_half_size": (torch.rand(2, 4, 8, 6), (4, 3), None, False),
"rand_one_and_half_scale": (torch.rand(2, 4, 8, 3), None, 1.5, False),
"rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False),
}


class UpsamplingBilinear2d(torch.nn.Module):
def __init__(
self,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
):
super().__init__()
self.upsample = torch.nn.UpsamplingBilinear2d( # noqa: TOR101
size=size, scale_factor=scale_factor
)

def forward(self, x):
return self.upsample(x)


class Upsample(torch.nn.Module):
def __init__(
self,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
):
super().__init__()
self.upsample = torch.nn.Upsample(
size=size, scale_factor=scale_factor, mode="bilinear", align_corners=True
)

def forward(self, x):
return self.upsample(x)


class Interpolate(torch.nn.Module):
def __init__(
self,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
):
super().__init__()
self.upsample = lambda x: torch.nn.functional.interpolate(
x, size=size, scale_factor=scale_factor, mode="bilinear", align_corners=True
)

def forward(self, x):
return self.upsample(x)


@common.parametrize("test_data", test_data_suite_tosa)
def test_upsample_bilinear2d_vec_tosa_MI_UpsamplingBilinear2d(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = TosaPipelineMI[input_t1](
UpsamplingBilinear2d(size, scale_factor),
(test_data,),
aten_op,
exir_op=[],
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()


@common.parametrize("test_data", test_data_suite_tosa)
def test_upsample_bilinear2d_vec_tosa_MI_Upsample(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = TosaPipelineMI[input_t1](
Upsample(size, scale_factor),
(test_data,),
aten_op,
exir_op=[],
)
if not compare_outputs:
pipeline.pop_stage(-1)

pipeline.run()


@common.parametrize("test_data", test_data_suite_tosa)
def test_upsample_bilinear2d_vec_tosa_MI_Interpolate(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = TosaPipelineMI[input_t1](
Interpolate(size, scale_factor),
(test_data,),
aten_op,
exir_op=[],
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()


@common.parametrize("test_data", test_data_suite_tosa)
def test_upsample_bilinear2d_vec_tosa_BI_UpsamplingBilinear2d(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = TosaPipelineBI[input_t1](
UpsamplingBilinear2d(size, scale_factor),
(test_data,),
aten_op,
exir_op=[],
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()


@common.parametrize("test_data", test_data_suite_tosa)
def test_upsample_bilinear2d_vec_tosa_BI_Upsample(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = TosaPipelineBI[input_t1](
Upsample(size, scale_factor),
(test_data,),
aten_op,
exir_op=[],
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()


@common.parametrize("test_data", test_data_suite_Uxx)
@common.XfailIfNoCorstone320
def test_upsample_bilinear2d_vec_U85_BI_Upsample(test_data: input_t1):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = EthosU85PipelineBI[input_t1](
Upsample(size, scale_factor),
(test_data,),
aten_op,
run_on_fvp=True,
qtol=1,
use_to_edge_transform_and_lower=True,
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()


@common.parametrize("test_data", test_data_suite_Uxx)
@common.XfailIfNoCorstone320
def test_upsample_bilinear2d_vec_U85_BI_Interpolate(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = EthosU85PipelineBI[input_t1](
Interpolate(size, scale_factor),
(test_data,),
aten_op,
run_on_fvp=True,
qtol=1,
use_to_edge_transform_and_lower=True,
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()


@common.parametrize("test_data", test_data_suite_Uxx)
@common.XfailIfNoCorstone320
def test_upsample_bilinear2d_vec_U85_BI_UpsamplingBilinear2d(
test_data: torch.Tensor,
):
test_data, size, scale_factor, compare_outputs = test_data

pipeline = EthosU85PipelineBI[input_t1](
UpsamplingBilinear2d(size, scale_factor),
(test_data,),
aten_op,
run_on_fvp=True,
qtol=1,
use_to_edge_transform_and_lower=True,
)
if not compare_outputs:
pipeline.pop_stage(-1)
pipeline.run()
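These tests can typically be run locally with pytest, e.g. `pytest backends/arm/test/ops/test_upsample_bilinear2d.py`; the Ethos-U85 cases are gated by `common.XfailIfNoCorstone320` and need a Corstone-320 FVP to execute on the hardware model.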
1 change: 1 addition & 0 deletions backends/arm/tosa_partitioner.py
@@ -170,6 +170,7 @@ def filter_fn(node: torch.fx.Node) -> bool:

ops_to_not_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.eye.default,
torch.ops.aten.linspace.default,