Skip to content

Commit 4c68660

Browse files
Tessilfreddan80
authored andcommitted
Add support for upsample_nearest2d op in the Arm backend
Change-Id: Id0b742214e5432957b2f573b4218f09a4d9734e4
1 parent 667f600 commit 4c68660

File tree

9 files changed

+354
-1
lines changed

9 files changed

+354
-1
lines changed

backends/arm/arm_partitioner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6969
exir_ops.edge.aten.sub.Tensor,
7070
exir_ops.edge.aten.sum.dim_IntList,
7171
exir_ops.edge.aten.tanh.default,
72+
exir_ops.edge.aten.upsample_nearest2d.vec,
7273
exir_ops.edge.aten.view_copy.default,
7374
exir_ops.edge.aten.clone.default,
7475
exir_ops.edge.aten.mean.dim,
@@ -144,5 +145,6 @@ def ops_to_not_decompose(
144145
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
145146
ops_to_not_decompose = [
146147
torch.ops.aten.linear.default,
148+
torch.ops.aten.upsample_nearest2d.vec,
147149
]
148150
return (ops_to_not_decompose, None)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@
3737
op_tanh,
3838
op_transpose,
3939
op_unsqueeze,
40+
op_upsample_nearest2d,
4041
op_view,
4142
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import serializer.tosa_serializer as ts
8+
import torch
9+
from executorch.backends.arm.operators.node_visitor import (
10+
NodeVisitor,
11+
register_node_visitor,
12+
)
13+
from executorch.backends.arm.tosa_mapping import TosaArg
14+
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
15+
from serializer.tosa_serializer import TosaOp
16+
17+
from tosa.ResizeMode import ResizeMode
18+
19+
20+
@register_node_visitor
21+
class UpsampleNearest2dVisitor(NodeVisitor):
22+
target = "aten.upsample_nearest2d.vec"
23+
24+
def __init__(self, *args):
25+
super().__init__(*args)
26+
27+
def define_node(
28+
self,
29+
node: torch.fx.Node,
30+
tosa_graph: ts.TosaSerializer,
31+
inputs: List[TosaArg],
32+
output: TosaArg,
33+
is_quant_node: bool,
34+
) -> None:
35+
assert (
36+
inputs[0].shape is not None and output.shape is not None
37+
), "Only static shapes are supported"
38+
39+
# tosa_shape output is NHWC, take HW
40+
input_size_yx = torch.tensor(
41+
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
42+
)
43+
# Ignore scale and size parameters, directly use the output size as
44+
# we only support static shapes currently
45+
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
46+
47+
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
48+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
49+
)
50+
51+
def in_int16_range(x):
52+
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
53+
54+
assert in_int16_range(scale_n_yx)
55+
assert in_int16_range(scale_d_yx)
56+
assert in_int16_range(border_yx)
57+
58+
attr = ts.TosaSerializerAttribute()
59+
attr.ResizeAttribute(
60+
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
61+
offset=offset_yx.tolist(),
62+
border=border_yx.tolist(),
63+
mode=ResizeMode.NEAREST,
64+
)
65+
66+
tosa_graph.addOperator(
67+
TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
68+
)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ class ArmQuantizer(Quantizer):
270270
"mm",
271271
"one_to_one",
272272
"generic",
273+
"upsample_nearest2d",
273274
]
274275

275276
def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ def decorator(annotator: AnnotatorType):
5959
mul_annotator,
6060
one_to_one_annotator,
6161
sub_annotator,
62+
upsample_nearest2d_annotator,
6263
)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import itertools
7+
from typing import Callable, List, Optional
8+
9+
import torch
10+
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
11+
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
12+
from torch.ao.quantization.quantizer import (
13+
QuantizationAnnotation,
14+
SharedQuantizationSpec,
15+
)
16+
from torch.fx import Node
17+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
18+
19+
20+
def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None):
21+
def filter(node: Node):
22+
is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec
23+
if filter_fn is None:
24+
return is_upsample
25+
else:
26+
return is_upsample and filter_fn(node)
27+
28+
return filter
29+
30+
31+
@register_annotator("upsample_nearest2d")
32+
def _annotate_upsample_nearest2d(
33+
gm: torch.fx.GraphModule,
34+
quantization_config: QuantizationConfig,
35+
filter_fn: Optional[Callable[[Node], bool]] = None,
36+
) -> Optional[List[List[Node]]]:
37+
module_partitions = get_source_partitions(
38+
gm.graph,
39+
[
40+
torch.nn.UpsamplingNearest2d,
41+
torch.nn.Upsample,
42+
torch.nn.functional.interpolate,
43+
],
44+
_filter_upsample_nearest2d(filter_fn),
45+
)
46+
upsample_partitions = list(
47+
itertools.chain.from_iterable(module_partitions.values())
48+
)
49+
annotated_partitions = []
50+
51+
for upsample_partition in upsample_partitions:
52+
annotated_partitions.append(upsample_partition.nodes)
53+
54+
assert len(upsample_partition.nodes) == 1
55+
upsample_node = upsample_partition.nodes[0]
56+
57+
input_act = upsample_node.args[0]
58+
assert isinstance(input_act, Node)
59+
60+
input_act_qspec = quantization_config.get_input_act_qspec()
61+
output_act_qspec = SharedQuantizationSpec((input_act, upsample_node))
62+
63+
upsample_node.meta["quantization_annotation"] = QuantizationAnnotation(
64+
input_qspec_map={
65+
input_act: input_act_qspec,
66+
},
67+
output_qspec=output_act_qspec,
68+
_annotated=True,
69+
)
70+
71+
return annotated_partitions
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from typing import Optional, Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14+
from parameterized import parameterized
15+
16+
17+
test_data_suite = [
18+
# (test_name, test_data, size, scale_factor, compare_outputs)
19+
("rand_double_scale", torch.rand(2, 4, 8, 3), None, 2.0, True),
20+
("rand_double_scale_one_dim", torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True),
21+
("rand_double_size", torch.rand(2, 4, 8, 3), (16, 6), None, True),
22+
("rand_one_double_scale", torch.rand(2, 4, 1, 1), None, 2.0, True),
23+
("rand_one_double_size", torch.rand(2, 4, 1, 1), (2, 2), None, True),
24+
("rand_one_same_scale", torch.rand(2, 4, 1, 1), None, 1.0, True),
25+
("rand_one_same_size", torch.rand(2, 4, 1, 1), (1, 1), None, True),
26+
# Can't compare outputs as the rounding when selecting the nearest pixel is
27+
# different between PyTorch and TOSA. Just check the legalization went well.
28+
# TODO Improve the test infrastructure to support more in depth verification
29+
# of the TOSA legalization results.
30+
("rand_half_scale", torch.rand(2, 4, 8, 6), None, 0.5, False),
31+
("rand_half_size", torch.rand(2, 4, 8, 6), (4, 3), None, False),
32+
("rand_one_and_half_scale", torch.rand(2, 4, 8, 3), None, 1.5, False),
33+
("rand_one_and_half_size", torch.rand(2, 4, 8, 3), (12, 4), None, False),
34+
]
35+
36+
37+
class TestUpsampleNearest2d(unittest.TestCase):
38+
class UpsamplingNearest2d(torch.nn.Module):
39+
def __init__(
40+
self,
41+
size: Optional[Tuple[int]],
42+
scale_factor: Optional[float | Tuple[float]],
43+
):
44+
super().__init__()
45+
self.upsample = torch.nn.UpsamplingNearest2d( # noqa: TOR101
46+
size=size, scale_factor=scale_factor
47+
)
48+
49+
def forward(self, x):
50+
return self.upsample(x)
51+
52+
class Upsample(torch.nn.Module):
53+
def __init__(
54+
self,
55+
size: Optional[Tuple[int]],
56+
scale_factor: Optional[float | Tuple[float]],
57+
):
58+
super().__init__()
59+
self.upsample = torch.nn.Upsample(
60+
size=size, scale_factor=scale_factor, mode="nearest"
61+
)
62+
63+
def forward(self, x):
64+
return self.upsample(x)
65+
66+
class Interpolate(torch.nn.Module):
67+
def __init__(
68+
self,
69+
size: Optional[Tuple[int]],
70+
scale_factor: Optional[float | Tuple[float]],
71+
):
72+
super().__init__()
73+
self.upsample = lambda x: torch.nn.functional.interpolate(
74+
x, size=size, scale_factor=scale_factor, mode="nearest"
75+
)
76+
77+
def forward(self, x):
78+
return self.upsample(x)
79+
80+
def _test_upsample_nearest_2d_tosa_MI_pipeline(
81+
self,
82+
module: torch.nn.Module,
83+
test_data: Tuple[torch.tensor],
84+
compare_outputs: bool,
85+
):
86+
tester = (
87+
ArmTester(
88+
module,
89+
example_inputs=test_data,
90+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
91+
)
92+
.export()
93+
.check(["torch.ops.aten.upsample_nearest2d.vec"])
94+
.check_not(["torch.ops.quantized_decomposed"])
95+
.to_edge_transform_and_lower()
96+
.check_not(["torch.ops.aten.upsample_nearest2d.vec"])
97+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
98+
.to_executorch()
99+
)
100+
101+
if compare_outputs:
102+
tester.run_method_and_compare_outputs(inputs=test_data)
103+
104+
def _test_upsample_nearest_2d_tosa_BI_pipeline(
105+
self,
106+
module: torch.nn.Module,
107+
test_data: Tuple[torch.tensor],
108+
compare_outputs: bool,
109+
):
110+
tester = (
111+
ArmTester(
112+
module,
113+
example_inputs=test_data,
114+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
115+
)
116+
.quantize()
117+
.export()
118+
.check(["torch.ops.aten.upsample_nearest2d.vec"])
119+
.check(["torch.ops.quantized_decomposed"])
120+
.to_edge_transform_and_lower()
121+
.check_not(["torch.ops.aten.upsample_nearest2d.vec"])
122+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
123+
.to_executorch()
124+
)
125+
126+
if compare_outputs:
127+
tester.run_method_and_compare_outputs(inputs=test_data)
128+
129+
@parameterized.expand(test_data_suite)
130+
def test_upsample_nearest_2d_tosa_MI(
131+
self,
132+
test_name: str,
133+
test_data: torch.Tensor,
134+
size: Optional[Tuple[int]],
135+
scale_factor: Optional[float | Tuple[float]],
136+
compare_outputs: bool,
137+
):
138+
self._test_upsample_nearest_2d_tosa_MI_pipeline(
139+
self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
140+
)
141+
self._test_upsample_nearest_2d_tosa_MI_pipeline(
142+
self.Upsample(size, scale_factor), (test_data,), compare_outputs
143+
)
144+
self._test_upsample_nearest_2d_tosa_MI_pipeline(
145+
self.Interpolate(size, scale_factor), (test_data,), compare_outputs
146+
)
147+
148+
@parameterized.expand(test_data_suite)
149+
def test_upsample_nearest_2d_tosa_BI(
150+
self,
151+
test_name: str,
152+
test_data: torch.Tensor,
153+
size: Optional[Tuple[int]],
154+
scale_factor: Optional[float | Tuple[float]],
155+
compare_outputs: bool,
156+
):
157+
self._test_upsample_nearest_2d_tosa_BI_pipeline(
158+
self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
159+
)
160+
self._test_upsample_nearest_2d_tosa_BI_pipeline(
161+
self.Upsample(size, scale_factor), (test_data,), compare_outputs
162+
)
163+
self._test_upsample_nearest_2d_tosa_BI_pipeline(
164+
self.Interpolate(size, scale_factor), (test_data,), compare_outputs
165+
)

backends/arm/test/tester/arm_tester.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def run_method_and_compare_outputs(
287287
inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
288288
The default is random data.
289289
"""
290-
291290
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
292291
if edge_stage is None:
293292
edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]

backends/arm/tosa_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,48 @@ def expand_dims(
298298
build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)
299299

300300
return intermediate
301+
302+
303+
def get_resize_parameters(
304+
input_size: torch.Tensor,
305+
output_size: torch.Tensor,
306+
resize_mode: int,
307+
align_corners: bool,
308+
):
309+
"""Get the tosa.resize parameters based on the input and output size.
310+
311+
Args:
312+
input_size (torch.Tensor): Size of the input
313+
output_size (torch.Tensor): Size of the output
314+
resize_mode (tosa.ResizeMode): The TOSA resize mode
315+
align_corners (bool): Align the corners pixels of the input and output
316+
317+
Returns:
318+
scale_n (torch.Tensor), scale_d (torch.Tensor),
319+
offset (torch.Tensor), border (torch.Tensor)
320+
"""
321+
assert torch.all(input_size > 0)
322+
assert torch.all(output_size > 0)
323+
324+
scale_n = torch.tensor(
325+
[
326+
so - 1 if align_corners and si > 1 and so > 1 else so
327+
for si, so in zip(input_size, output_size)
328+
]
329+
)
330+
scale_d = torch.tensor(
331+
[
332+
si - 1 if align_corners and si > 1 and so > 1 else si
333+
for si, so in zip(input_size, output_size)
334+
]
335+
)
336+
337+
gcd = torch.gcd(scale_n, scale_d)
338+
scale_n = scale_n // gcd
339+
scale_d = scale_d // gcd
340+
341+
# No half-pixel centre support in PyTorch, no offset needed
342+
offset = torch.zeros_like(input_size)
343+
border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset
344+
345+
return scale_n, scale_d, offset, border

0 commit comments

Comments
 (0)