Skip to content

Commit e19677c

Browse files
Arm backend: Add squeeze-op (#5681)
Summary: - Add squeeze op and unittest - Update unsqueeze test Pull Request resolved: #5681 Reviewed By: digantdesai Differential Revision: D63637246 Pulled By: mergennachin fbshipit-source-id: bf068d0305ea9f499f33ed27b5050eb09829b353
1 parent 68548e5 commit e19677c

File tree

7 files changed

+279
-7
lines changed

7 files changed

+279
-7
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6666
exir_ops.edge.aten.clone.default,
6767
exir_ops.edge.aten.mean.dim,
6868
exir_ops.edge.aten.unsqueeze_copy.default,
69+
exir_ops.edge.aten.squeeze_copy.dims,
6970
operator.getitem,
7071
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
7172
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
op_sigmoid,
3333
op_slice,
3434
op_softmax,
35+
op_squeeze,
3536
op_sub,
3637
op_unsqueeze,
3738
op_view,

backends/arm/operators/op_squeeze.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import serializer.tosa_serializer as ts
8+
import torch
9+
from executorch.backends.arm.operators.node_visitor import (
10+
NodeVisitor,
11+
register_node_visitor,
12+
)
13+
from executorch.backends.arm.tosa_mapping import TosaArg
14+
from executorch.backends.arm.tosa_utils import tosa_shape
15+
from serializer.tosa_serializer import TosaOp
16+
17+
18+
@register_node_visitor
19+
class SqueezeVisitor(NodeVisitor):
20+
target = "aten.squeeze_copy.dims"
21+
22+
def define_node(
23+
self,
24+
node: torch.fx.Node,
25+
tosa_graph: ts.TosaSerializer,
26+
inputs: List[TosaArg],
27+
output: TosaArg,
28+
is_quant_node: bool,
29+
) -> None:
30+
shape = inputs[0].shape
31+
rank = len(shape)
32+
# In some cases, e.g. torch.randn((1, 5, 1, 5)).squeeze(),
33+
# dims == [0, 1, 2, 3] even though all dims cannot be squeezed.
34+
# We need to verify that shape[dim] == 1 before squeezing the dim.
35+
dims = [dim % rank for dim in inputs[1].special if shape[dim] == 1]
36+
new_shape = [shape[i] for i in range(rank) if i not in dims]
37+
new_shape = tosa_shape(new_shape, output.dim_order)
38+
attr = ts.TosaSerializerAttribute()
39+
attr.ReshapeAttribute(new_shape)
40+
tosa_graph.addOperator(
41+
TosaOp.Op().RESHAPE, [inputs[0].name], [output.name], attr
42+
)

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
146146
torch.ops.aten.permute.default,
147147
torch.ops.aten.permute_copy.default,
148148
torch.ops.aten.squeeze.dim,
149+
torch.ops.aten.squeeze.dims,
150+
torch.ops.aten.squeeze.default,
149151
torch.ops.aten.squeeze_copy.dim,
152+
torch.ops.aten.unsqueeze.default,
150153
# TODO: remove?
151154
torch.ops.aten.adaptive_avg_pool2d.default,
152155
torch.ops.aten.view_copy.default,

backends/arm/test/ops/test_squeeze.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
#
8+
# Tests the squeeze op which squeezes a given dimension with size 1 into a lower ranked tensor.
9+
#
10+
11+
import unittest
12+
from typing import Optional, Tuple
13+
14+
import torch
15+
16+
from executorch.backends.arm.quantizer.arm_quantizer import (
17+
ArmQuantizer,
18+
get_symmetric_quantization_config,
19+
)
20+
from executorch.backends.arm.test import common
21+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
22+
23+
from executorch.backends.xnnpack.test.tester.tester import Quantize
24+
from executorch.exir.backend.compile_spec_schema import CompileSpec
25+
from parameterized import parameterized
26+
27+
28+
class TestSqueeze(unittest.TestCase):
29+
class SqueezeDim(torch.nn.Module):
30+
test_parameters: list[tuple[torch.Tensor, int]] = [
31+
(torch.randn(1, 1, 5), -2),
32+
(torch.randn(1, 2, 3, 1), 3),
33+
(torch.randn(1, 5, 1, 5), -2),
34+
]
35+
36+
def forward(self, x: torch.Tensor, dim: int):
37+
return x.squeeze(dim)
38+
39+
class SqueezeDims(torch.nn.Module):
40+
test_parameters: list[tuple[torch.Tensor, tuple[int]]] = [
41+
(torch.randn(1, 5, 5, 1), (0, -1)),
42+
(torch.randn(1, 5, 1, 5), (0, -2)),
43+
]
44+
45+
def forward(self, x: torch.Tensor, dims: tuple[int]):
46+
return x.squeeze(dims)
47+
48+
class Squeeze(torch.nn.Module):
49+
test_parameters: list[tuple[torch.Tensor]] = [
50+
(torch.randn(1, 5, 5, 1),),
51+
(torch.randn(1, 5, 1, 5),),
52+
]
53+
54+
def forward(self, x: torch.Tensor):
55+
return x.squeeze()
56+
57+
def _test_squeeze_tosa_MI_pipeline(
58+
self,
59+
module: torch.nn.Module,
60+
test_data: Tuple[torch.Tensor, Optional[tuple[int]]],
61+
export_target: str,
62+
):
63+
(
64+
ArmTester(
65+
module,
66+
example_inputs=test_data,
67+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
68+
)
69+
.export()
70+
.check_count({export_target: 1})
71+
.to_edge()
72+
.partition()
73+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
74+
.to_executorch()
75+
.run_method_and_compare_outputs(inputs=test_data)
76+
)
77+
78+
def _test_squeeze_tosa_BI_pipeline(
79+
self,
80+
module: torch.nn.Module,
81+
test_data: Tuple[torch.Tensor, Optional[tuple[int]]],
82+
export_target: str,
83+
):
84+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
85+
(
86+
ArmTester(
87+
module,
88+
example_inputs=test_data,
89+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
90+
)
91+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
92+
.export()
93+
.check_count({export_target: 1})
94+
.to_edge()
95+
.partition()
96+
.dump_artifact()
97+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
98+
.to_executorch()
99+
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
100+
)
101+
102+
def _test_squeeze_ethosu_BI_pipeline(
103+
self,
104+
compile_spec: CompileSpec,
105+
module: torch.nn.Module,
106+
test_data: Tuple[torch.Tensor, Optional[tuple[int]]],
107+
export_target: str,
108+
):
109+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
110+
(
111+
ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
112+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
113+
.export()
114+
.check_count({export_target: 1})
115+
.to_edge()
116+
.partition()
117+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
118+
.to_executorch()
119+
)
120+
121+
@parameterized.expand(Squeeze.test_parameters)
122+
def test_squeeze_tosa_MI(
123+
self,
124+
test_tensor: torch.Tensor,
125+
):
126+
self._test_squeeze_tosa_MI_pipeline(
127+
self.Squeeze(), (test_tensor,), "torch.ops.aten.squeeze.default"
128+
)
129+
130+
@parameterized.expand(Squeeze.test_parameters)
131+
def test_squeeze_tosa_BI(
132+
self,
133+
test_tensor: torch.Tensor,
134+
):
135+
self._test_squeeze_tosa_BI_pipeline(
136+
self.Squeeze(), (test_tensor,), "torch.ops.aten.squeeze.default"
137+
)
138+
139+
@parameterized.expand(Squeeze.test_parameters)
140+
def test_squeeze_u55_BI(
141+
self,
142+
test_tensor: torch.Tensor,
143+
):
144+
self._test_squeeze_ethosu_BI_pipeline(
145+
common.get_u55_compile_spec(permute_memory_to_nhwc=False),
146+
self.Squeeze(),
147+
(test_tensor,),
148+
"torch.ops.aten.squeeze.default",
149+
)
150+
151+
@parameterized.expand(Squeeze.test_parameters)
152+
def test_squeeze_u85_BI(
153+
self,
154+
test_tensor: torch.Tensor,
155+
):
156+
self._test_squeeze_ethosu_BI_pipeline(
157+
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
158+
self.Squeeze(),
159+
(test_tensor,),
160+
"torch.ops.aten.squeeze.default",
161+
)
162+
163+
@parameterized.expand(SqueezeDim.test_parameters)
164+
def test_squeeze_dim_tosa_MI(self, test_tensor: torch.Tensor, dim: int):
165+
self._test_squeeze_tosa_MI_pipeline(
166+
self.SqueezeDim(), (test_tensor, dim), "torch.ops.aten.squeeze.dim"
167+
)
168+
169+
@parameterized.expand(SqueezeDim.test_parameters)
170+
def test_squeeze_dim_tosa_BI(self, test_tensor: torch.Tensor, dim: int):
171+
self._test_squeeze_tosa_BI_pipeline(
172+
self.SqueezeDim(), (test_tensor, dim), "torch.ops.aten.squeeze.dim"
173+
)
174+
175+
@parameterized.expand(SqueezeDim.test_parameters)
176+
def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
177+
self._test_squeeze_ethosu_BI_pipeline(
178+
common.get_u55_compile_spec(permute_memory_to_nhwc=False),
179+
self.SqueezeDim(),
180+
(test_tensor, dim),
181+
"torch.ops.aten.squeeze.dim",
182+
)
183+
184+
@parameterized.expand(SqueezeDim.test_parameters)
185+
def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
186+
self._test_squeeze_ethosu_BI_pipeline(
187+
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
188+
self.SqueezeDim(),
189+
(test_tensor, dim),
190+
"torch.ops.aten.squeeze.dim",
191+
)
192+
193+
@parameterized.expand(SqueezeDims.test_parameters)
194+
def test_squeeze_dims_tosa_MI(self, test_tensor: torch.Tensor, dims: tuple[int]):
195+
self._test_squeeze_tosa_MI_pipeline(
196+
self.SqueezeDims(), (test_tensor, dims), "torch.ops.aten.squeeze.dims"
197+
)
198+
199+
@parameterized.expand(SqueezeDims.test_parameters)
200+
def test_squeeze_dims_tosa_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
201+
self._test_squeeze_tosa_BI_pipeline(
202+
self.SqueezeDims(), (test_tensor, dims), "torch.ops.aten.squeeze.dims"
203+
)
204+
205+
@parameterized.expand(SqueezeDims.test_parameters)
206+
def test_squeeze_dims_u55_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
207+
self._test_squeeze_ethosu_BI_pipeline(
208+
common.get_u55_compile_spec(permute_memory_to_nhwc=False),
209+
self.SqueezeDims(),
210+
(test_tensor, dims),
211+
"torch.ops.aten.squeeze.dims",
212+
)
213+
214+
@parameterized.expand(SqueezeDims.test_parameters)
215+
def test_squeeze_dims_u85_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
216+
self._test_squeeze_ethosu_BI_pipeline(
217+
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
218+
self.SqueezeDims(),
219+
(test_tensor, dims),
220+
"torch.ops.aten.squeeze.dims",
221+
)

backends/arm/test/ops/test_unsqueeze.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
class TestSimpleUnsqueeze(unittest.TestCase):
2929
class Unsqueeze(torch.nn.Module):
30-
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 5, 5)]
31-
test_parameters: list[tuple[torch.Tensor]] = [(torch.ones(n),) for n in shapes]
30+
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 4, 3)]
31+
test_parameters: list[tuple[torch.Tensor]] = [(torch.randn(n),) for n in shapes]
3232

3333
def forward(self, x: torch.Tensor, dim):
3434
return x.unsqueeze(dim)
@@ -40,7 +40,7 @@ def _test_unsqueeze_tosa_MI_pipeline(
4040
ArmTester(
4141
module,
4242
example_inputs=test_data,
43-
compile_spec=common.get_tosa_compile_spec(),
43+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
4444
)
4545
.export()
4646
.check_count({"torch.ops.aten.unsqueeze.default": 1})
@@ -59,7 +59,7 @@ def _test_unsqueeze_tosa_BI_pipeline(
5959
ArmTester(
6060
module,
6161
example_inputs=test_data,
62-
compile_spec=common.get_tosa_compile_spec(),
62+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
6363
)
6464
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
6565
.export()
@@ -105,11 +105,15 @@ def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
105105
@parameterized.expand(Unsqueeze.test_parameters)
106106
def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
107107
self._test_unsqueeze_ethosu_BI_pipeline(
108-
common.get_u55_compile_spec(), self.Unsqueeze(), (test_tensor, 0)
108+
common.get_u55_compile_spec(permute_memory_to_nhwc=False),
109+
self.Unsqueeze(),
110+
(test_tensor, 0),
109111
)
110112

111113
@parameterized.expand(Unsqueeze.test_parameters)
112114
def test_unsqueeze_u85_BI(self, test_tensor: torch.Tensor):
113115
self._test_unsqueeze_ethosu_BI_pipeline(
114-
common.get_u85_compile_spec(), self.Unsqueeze(), (test_tensor, 0)
116+
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
117+
self.Unsqueeze(),
118+
(test_tensor, 0),
115119
)

backends/arm/test/runner_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def run_tosa_ref_model(
407407
def prep_data_for_save(
408408
data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
409409
):
410-
data_np = data.detach().numpy().astype(np.float32)
410+
data_np = np.array(data.detach(), order="C").astype(np.float32)
411411

412412
if is_quantized:
413413
assert (

0 commit comments

Comments
 (0)