Skip to content

Arm backend: Support aten.full_like #8455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -18,6 +18,9 @@
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
)
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
Expand Down Expand Up @@ -95,6 +98,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(ConvertFullLikeToFullPass())

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
Expand Down Expand Up @@ -133,7 +137,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxesPass())

self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
Expand Down
33 changes: 33 additions & 0 deletions backends/arm/_passes/convert_full_like_to_full_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertFullLikeToFullPass(ExportPass):
    """Replace ``aten.full_like`` with an equivalent ``aten.full``.

    As per the full_like pytorch documentation,
    `torch.full_like(input, fill_value)` is equivalent to
    `torch.full(input.size(),
                fill_value,
                dtype=input.dtype,
                layout=input.layout,
                device=input.device
                )`
    Skip layout and device since it's not relevant for our backend.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Only rewrite full_like; pass every other op through untouched.
        if op != exir_ops.edge.aten.full_like.default:
            return super().call_operator(op, args, kwargs, meta)

        tensor = args[0].data
        # An explicit dtype kwarg on full_like overrides the input's dtype
        # (torch.full_like semantics); otherwise inherit from the input.
        # NOTE(review): layout/device/memory_format kwargs are deliberately
        # dropped as irrelevant for this backend.
        dtype = kwargs.get("dtype") or tensor.dtype
        full_args = (list(tensor.shape), args[1])
        full_kwargs = {"dtype": dtype}
        return super().call_operator(
            exir_ops.edge.aten.full.default, full_args, full_kwargs, meta
        )
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.le.Tensor,
Expand Down
9 changes: 9 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _match_pattern(
torch.ops.aten.sum.dim_IntList,
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.full_like.default,
]

_one_to_one_shared_input_qspec = [
Expand Down Expand Up @@ -379,3 +380,11 @@ def annotate_graph( # type: ignore[return]
_annotate_output(node, quant_properties.quant_output)

arm_quantizer_utils.mark_node_as_annotated(node) # type: ignore[attr-defined]

# Quantization does not allow kwargs for some reason.
# Remove from ops we know have and where we know it does not break anything.
if node.target in [
torch.ops.aten.full_like.default,
torch.ops.aten.full.default,
]:
node.kwargs = {}
1 change: 0 additions & 1 deletion backends/arm/test/models/test_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class TestConformer(unittest.TestCase):
# .to_executorch step, i.e. after Arm partitioner.
ops_after_partitioner = {
"executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
"executorch_exir_dialects_edge__ops_aten_full_like_default": 4,
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
"executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
Expand Down
40 changes: 28 additions & 12 deletions backends/arm/test/ops/test_full.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -36,8 +35,8 @@ def forward(self, x: torch.Tensor):
return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

class AddVariableFull(torch.nn.Module):
sizes = [
(5),
sizes: list[tuple[int, ...]] = [
(5,),
(5, 5),
(5, 5, 5),
(1, 5, 5, 5),
Expand All @@ -48,6 +47,21 @@ def forward(self, x: torch.Tensor, y):
# Input + a full with the shape from the input and a given value 'y'.
return x + torch.full(x.shape, y)

class FullLike(torch.nn.Module):
    """Module exercising full_like.

    full_like is rewritten into full by a pass, so coverage on the
    reference model is sufficient; no FVP run is needed.
    """

    test_parameters = [
        ((torch.randn(2, 2, 2, 2) * 50, 3.2),),
        ((torch.randn(2, 2, 2, 2) * 50, 3),),
        (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3.2),),
        (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3),),
    ]

    def forward(self, input_tensor: torch.Tensor, value):
        # Keep input_tensor alive in the graph: once full_like is turned
        # into full it would otherwise be left without users, which our
        # backend can't handle. Adding it to the output gives it a user.
        filled = torch.full_like(input_tensor, value)
        return input_tensor + filled

def _test_full_tosa_MI_pipeline(
self,
module: torch.nn.Module,
Expand All @@ -63,9 +77,7 @@ def _test_full_tosa_MI_pipeline(
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
.check_count({"torch.ops.aten.full.default": 1})
.to_edge()
.partition()
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
Expand All @@ -85,9 +97,7 @@ def _test_full_tosa_BI_pipeline(
)
.quantize()
.export()
.check_count({"torch.ops.aten.full.default": 1})
.to_edge()
.partition()
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
Expand All @@ -101,9 +111,7 @@ def _test_full_tosa_ethos_pipeline(
ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
.quantize()
.export()
.check_count({"torch.ops.aten.full.default": 1})
.to_edge()
.partition()
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
Expand All @@ -129,6 +137,10 @@ def test_const_full_tosa_MI(self):
_input = torch.rand((2, 2, 3, 3)) * 10
self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

@parameterized.expand(FullLike.test_parameters)
def test_full_like_tosa_MI(self, test_tensor: Tuple):
self._test_full_tosa_MI_pipeline(self.FullLike(), test_tensor)

def test_const_full_nhwc_tosa_BI(self):
_input = torch.rand((2, 2, 3, 3)) * 10
self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,))
Expand All @@ -143,6 +155,10 @@ def test_full_tosa_MI(self, test_tensor: Tuple):
def test_full_tosa_BI(self, test_tensor: Tuple):
self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor)

@parameterized.expand(FullLike.test_parameters)
def test_full_like_tosa_BI(self, test_tensor: Tuple):
self._test_full_tosa_BI_pipeline(self.FullLike(), test_tensor)

@parameterized.expand(AddVariableFull.test_parameters)
@pytest.mark.corstone_fvp
def test_full_u55_BI(self, test_tensor: Tuple):
Expand Down
2 changes: 1 addition & 1 deletion examples/arm/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ ethos_u_base_rev="24.08"

# tosa reference model
tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model"
tosa_reference_model_rev="v0.80.1"
tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"

# vela
vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
Expand Down
Loading