Skip to content

Commit ef9c07f

Browse files
authored
Add unsqueeze op to Arm backend
Differential Revision: D61718607 Pull Request resolved: #4833
1 parent e3ac39a commit ef9c07f

File tree

4 files changed

+156
-0
lines changed

4 files changed

+156
-0
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5959
exir_ops.edge.aten.view_copy.default,
6060
exir_ops.edge.aten.clone.default,
6161
exir_ops.edge.aten.mean.dim,
62+
exir_ops.edge.aten.unsqueeze_copy.default,
6263
operator.getitem,
6364
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
6465
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@
2727
op_slice,
2828
op_softmax,
2929
op_sub,
30+
op_unsqueeze,
3031
op_view,
3132
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# Follows this specification: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
7+
8+
import serializer.tosa_serializer as ts
9+
import torch.fx
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_utils import tosa_shape
16+
from serializer.tosa_serializer import TosaOp
17+
18+
19+
@register_node_visitor
20+
class UnsqueezeVisitor(NodeVisitor):
21+
target = "aten.unsqueeze_copy.default"
22+
23+
def __init__(self, *args):
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
tosa_graph: ts.TosaSerializer,
30+
inputs: list[TosaArg],
31+
output: TosaArg,
32+
is_quant_node: bool,
33+
) -> None:
34+
35+
dim = inputs[1].number
36+
shape = inputs[0].shape
37+
rank = len(shape)
38+
39+
assert -rank - 1 <= dim < rank + 1
40+
if dim < 0:
41+
dim = dim + rank + 1
42+
43+
new_shape = list(shape)
44+
new_shape.insert(dim, 1)
45+
new_shape = tosa_shape(new_shape, output.dim_order)
46+
47+
attr = ts.TosaSerializerAttribute()
48+
attr.ReshapeAttribute(new_shape)
49+
tosa_graph.addOperator(
50+
TosaOp.Op().RESHAPE, [inputs[0].name], [output.name], attr
51+
)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
#
8+
# Tests the unsqueeze op which copies the data of the input tensor (possibly with new data format)
9+
#
10+
11+
import unittest
12+
from typing import Sequence, Tuple
13+
14+
import torch
15+
16+
from executorch.backends.arm.quantizer.arm_quantizer import (
17+
ArmQuantizer,
18+
get_symmetric_quantization_config,
19+
)
20+
from executorch.backends.arm.test import common
21+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
22+
23+
from executorch.backends.xnnpack.test.tester.tester import Quantize
24+
from parameterized import parameterized
25+
26+
27+
class TestSimpleUnsqueeze(unittest.TestCase):
28+
class Unsqueeze(torch.nn.Module):
29+
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 5, 5)]
30+
test_parameters: list[tuple[torch.Tensor]] = [(torch.ones(n),) for n in shapes]
31+
32+
def forward(self, x: torch.Tensor, dim):
33+
return x.unsqueeze(dim)
34+
35+
def _test_unsqueeze_tosa_MI_pipeline(
36+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int]
37+
):
38+
(
39+
ArmTester(
40+
module,
41+
example_inputs=test_data,
42+
compile_spec=common.get_tosa_compile_spec(),
43+
)
44+
.export()
45+
.check_count({"torch.ops.aten.unsqueeze.default": 1})
46+
.to_edge()
47+
.partition()
48+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
49+
.to_executorch()
50+
.run_method_and_compare_outputs(inputs=test_data)
51+
)
52+
53+
def _test_unsqueeze_tosa_BI_pipeline(
54+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int]
55+
):
56+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
57+
(
58+
ArmTester(
59+
module,
60+
example_inputs=test_data,
61+
compile_spec=common.get_tosa_compile_spec(),
62+
)
63+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
64+
.export()
65+
.check_count({"torch.ops.aten.unsqueeze.default": 1})
66+
.to_edge()
67+
.partition()
68+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
69+
.to_executorch()
70+
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
71+
)
72+
73+
def _test_unsqueeze_tosa_u55_pipeline(
74+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int]
75+
):
76+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
77+
(
78+
ArmTester(
79+
module,
80+
example_inputs=test_data,
81+
compile_spec=common.get_u55_compile_spec(),
82+
)
83+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
84+
.export()
85+
.check_count({"torch.ops.aten.unsqueeze.default": 1})
86+
.to_edge()
87+
.partition()
88+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
89+
.to_executorch()
90+
)
91+
92+
@parameterized.expand(Unsqueeze.test_parameters)
93+
def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor):
94+
for i in range(-test_tensor.dim() - 1, test_tensor.dim() + 1):
95+
self._test_unsqueeze_tosa_MI_pipeline(self.Unsqueeze(), (test_tensor, i))
96+
97+
@parameterized.expand(Unsqueeze.test_parameters)
98+
def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
99+
self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0))
100+
101+
@parameterized.expand(Unsqueeze.test_parameters)
102+
def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
103+
self._test_unsqueeze_tosa_u55_pipeline(self.Unsqueeze(), (test_tensor, 0))

0 commit comments

Comments
 (0)