Commit 0222074

Arm backend: Support aten.full_like (#8455)
full_like is replaced with full. full_like is annotated with the one_to_one annotation since the output value does not depend on the input, except for its shape and dtype. The reference_model SHA is updated to include the commit that adds support for boolean input.

Signed-off-by: Erik Lundell <[email protected]>
1 parent 56f3a55 commit 0222074
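
For context, a minimal sketch (plain PyTorch, not part of this commit) of the equivalence the rewrite relies on: full_like produces the same tensor as full called with the input's shape and dtype, which is why layout and device can be dropped for this backend.

import torch

x = (torch.randn(2, 2, 2, 2) * 50).to(torch.int32)

# What appears in the graph before the pass ...
a = torch.full_like(x, 3)
# ... and the form it is rewritten into: an explicit shape and dtype,
# with layout and device dropped since they do not matter for this backend.
b = torch.full(list(x.shape), 3, dtype=x.dtype)

assert torch.equal(a, b)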

7 files changed, +78 -16 lines changed


backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -18,6 +18,9 @@
 from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
     ConvertExpandCopyToRepeatPass,
 )
+from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
+    ConvertFullLikeToFullPass,
+)
 from executorch.backends.arm._passes.convert_split_to_slice import (
     ConvertSplitToSlicePass,
 )
@@ -97,6 +100,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(ConvertMeanDimToAveragePoolPass())
+        self.add_pass(ConvertFullLikeToFullPass())

         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
@@ -135,7 +139,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxesPass())
-
+        self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
backends/arm/_passes/convert_full_like_to_full_pass.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class ConvertFullLikeToFullPass(ExportPass):
+    """As per the full_like pytorch documentation,
+    `torch.full_like(input, fill_value)` is equivalent to
+    `torch.full(input.size(),
+                fill_value,
+                dtype=input.dtype,
+                layout=input.layout,
+                device=input.device
+    )`
+    Skip layout and device since it's not relevant for our backend.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [
+            exir_ops.edge.aten.full_like.default,
+        ]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        tensor = args[0].data
+        full_args = (list(tensor.shape), args[1])
+        full_kwargs = {"dtype": tensor.dtype}
+        return super().call_operator(
+            exir_ops.edge.aten.full.default, full_args, full_kwargs, meta
+        )
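
To show where this node comes from in the first place, here is a minimal sketch (plain torch.export, not part of this commit; AddFullLike is a hypothetical toy module mirroring the FullLike test below). A call to torch.full_like in a model should surface as an aten.full_like node after export, which this pass then rewrites to aten.full once the program is in the edge dialect.

import torch
from torch.export import export

class AddFullLike(torch.nn.Module):
    def forward(self, x):
        # Involve the input in the output so the tensor has a user,
        # matching the constraint noted in the FullLike test below.
        return x + torch.full_like(x, 3.0)

ep = export(AddFullLike(), (torch.randn(2, 2, 2, 2),))
# The exported graph should contain an aten.full_like.default call;
# ConvertFullLikeToFullPass rewrites it to aten.full.default with an
# explicit shape and dtype when lowering for the Arm backend.
print(ep.graph)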

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.linear.default,
             exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.full_like.default,
             exir_ops.edge.aten.ge.Tensor,
             exir_ops.edge.aten.gt.Tensor,
             exir_ops.edge.aten.le.Tensor,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 9 additions & 0 deletions
@@ -134,6 +134,7 @@ def _match_pattern(
     torch.ops.aten.sum.dim_IntList,
     torch.ops.aten.hardsigmoid.default,
     torch.ops.aten.hardswish.default,
+    torch.ops.aten.full_like.default,
 ]

 _one_to_one_shared_input_qspec = [
@@ -379,3 +380,11 @@ def annotate_graph(  # type: ignore[return]
         _annotate_output(node, quant_properties.quant_output)

         arm_quantizer_utils.mark_node_as_annotated(node)  # type: ignore[attr-defined]
+
+        # Quantization does not allow kwargs for some reason.
+        # Remove from ops we know have and where we know it does not break anything.
+        if node.target in [
+            torch.ops.aten.full_like.default,
+            torch.ops.aten.full.default,
+        ]:
+            node.kwargs = {}

backends/arm/test/models/test_conformer.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ class TestConformer(unittest.TestCase):
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
-        "executorch_exir_dialects_edge__ops_aten_full_like_default": 4,
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
         "executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
         "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,

backends/arm/test/ops/test_full.py

Lines changed: 28 additions & 12 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -36,8 +35,8 @@ def forward(self, x: torch.Tensor):
             return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

     class AddVariableFull(torch.nn.Module):
-        sizes = [
-            (5),
+        sizes: list[tuple[int, ...]] = [
+            (5,),
             (5, 5),
             (5, 5, 5),
             (1, 5, 5, 5),
@@ -48,6 +47,21 @@ def forward(self, x: torch.Tensor, y):
             # Input + a full with the shape from the input and a given value 'y'.
             return x + torch.full(x.shape, y)

+    class FullLike(torch.nn.Module):
+        """Since full_like is replaced with full, we only need to test on reference model, not FVP."""
+
+        test_parameters = [
+            ((torch.randn(2, 2, 2, 2) * 50, 3.2),),
+            ((torch.randn(2, 2, 2, 2) * 50, 3),),
+            (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3.2),),
+            (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3),),
+        ]
+
+        def forward(self, input_tensor: torch.Tensor, value):
+            # Our backend can't handle tensors without users, which input_tensor doesn't have
+            # when the full_like is converted to a full. Therefore involve it in the output.
+            return input_tensor + torch.full_like(input_tensor, value)
+
     def _test_full_tosa_MI_pipeline(
         self,
         module: torch.nn.Module,
@@ -63,9 +77,7 @@ def _test_full_tosa_MI_pipeline(
                 compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
             )
             .export()
-            .check_count({"torch.ops.aten.full.default": 1})
-            .to_edge()
-            .partition()
+            .to_edge_transform_and_lower()
             .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
@@ -85,9 +97,7 @@ def _test_full_tosa_BI_pipeline(
             )
             .quantize()
             .export()
-            .check_count({"torch.ops.aten.full.default": 1})
-            .to_edge()
-            .partition()
+            .to_edge_transform_and_lower()
             .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
@@ -101,9 +111,7 @@ def _test_full_tosa_ethos_pipeline(
             ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
             .quantize()
             .export()
-            .check_count({"torch.ops.aten.full.default": 1})
-            .to_edge()
-            .partition()
+            .to_edge_transform_and_lower()
             .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
@@ -129,6 +137,10 @@ def test_const_full_tosa_MI(self):
         _input = torch.rand((2, 2, 3, 3)) * 10
         self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

+    @parameterized.expand(FullLike.test_parameters)
+    def test_full_like_tosa_MI(self, test_tensor: Tuple):
+        self._test_full_tosa_MI_pipeline(self.FullLike(), test_tensor)
+
     def test_const_full_nhwc_tosa_BI(self):
         _input = torch.rand((2, 2, 3, 3)) * 10
         self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,))
@@ -143,6 +155,10 @@ def test_full_tosa_MI(self, test_tensor: Tuple):
     def test_full_tosa_BI(self, test_tensor: Tuple):
         self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor)

+    @parameterized.expand(FullLike.test_parameters)
+    def test_full_like_tosa_BI(self, test_tensor: Tuple):
+        self._test_full_tosa_BI_pipeline(self.FullLike(), test_tensor)
+
     @parameterized.expand(AddVariableFull.test_parameters)
     @pytest.mark.corstone_fvp
     def test_full_u55_BI(self, test_tensor: Tuple):

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ ethos_u_base_rev="24.08"

 # tosa reference model
 tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
-tosa_reference_model_rev="v0.80.1"
+tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"

 # vela
 vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
