Skip to content

Commit 6f015f6

Browse files
Arm backend: Improve broadcasting (#10940)
Ethos-U55 only supports broadcasting of one argument. This patch introduces a pass which will insert repeat ops to make sure that only one input needs to be broadcasted. The pass is only applied for Ethos-U55. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 12af535 commit 6f015f6

File tree

8 files changed

+138
-0
lines changed

8 files changed

+138
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
1010
from .arm_pass import ArmPass # noqa
11+
from .broadcast_args_pass import BroadcastArgsPass # noqa
1112
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1213
from .cast_to_int32_pass import CastToInt32Pass # noqa
1314
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from executorch.backends.arm._passes import (
1111
AnnotateChannelsLastDimOrder,
1212
AnnotateDecomposedMatmulPass,
13+
BroadcastArgsPass,
1314
CastInt64BuffersToInt32Pass,
1415
CastToInt32Pass,
1516
ComputeConstantOpsAOT,
@@ -104,6 +105,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
104105
self.add_pass(RetraceFoldedDtypesPass())
105106
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
106107
self.add_pass(MatchArgRanksPass(exported_program))
108+
if self.tosa_spec.is_U55_subset:
109+
self.add_pass(BroadcastArgsPass())
107110
self.add_pass(ComputeConstantOpsAOT(exported_program))
108111

109112
self.add_pass(RemoveClonePass())
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
8+
from executorch.backends.arm._passes.arm_pass_utils import (
9+
create_node,
10+
get_first_fake_tensor,
11+
)
12+
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
15+
from executorch.exir.pass_base import PassResult
16+
from torch.fx import GraphModule, Node
17+
18+
19+
class BroadcastArgsPass(ArmPass):
    """
    Pass to manually broadcast arguments by inserting repeats.
    This is done when more than one arg needs broadcasting.

    For each targeted op, the first argument whose shape differs from the
    op's output shape is left untouched. Every further argument whose shape
    differs is routed through an explicit ``aten.repeat`` that tiles it up to
    the output shape, so at most one input is left with a non-output shape.
    """

    targeted_ops = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        # mul is indirectly targeting div as div is decomposed to reciprocal + mul
        exir_ops.edge.aten.mul.Tensor,
    }

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Insert repeat nodes in front of targeted ops with several mismatched args.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A PassResult holding the (possibly re-processed) graph module and
            a flag telling whether any repeat node was inserted.
        """
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            output_shape = get_first_fake_tensor(node).shape
            nbr_of_broadcasts = 0
            # node.args is a tuple; replace_input_with() below rebinds
            # node.args, so this loop keeps walking the original arguments.
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue

                shape = get_first_fake_tensor(arg).shape
                if shape == output_shape:
                    continue

                nbr_of_broadcasts += 1
                if nbr_of_broadcasts == 1:
                    # The first mismatching argument is kept as is.
                    continue

                # Indexing `shape` by the output rank assumes both shapes have
                # the same rank; integer division keeps the multiples exact.
                multiples = [
                    out_dim // shape[d] for d, out_dim in enumerate(output_shape)
                ]
                with graph_module.graph.inserting_before(node):
                    repeat = create_node(
                        graph_module.graph,
                        exir_ops.edge.aten.repeat.default,
                        args=(arg, multiples),
                        kwargs={},
                        from_node=node,
                    )
                    node.replace_input_with(arg, repeat)
                modified = True

        if modified:
            # Only regenerate code and re-run the base-class pass when the
            # graph was actually changed.
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)

backends/arm/test/ops/test_add.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
6060
10000 * torch.randn(1, 1, 4, 4),
6161
torch.randn(1, 1, 4, 1),
6262
),
63+
"4d_randn_1_mutltiple_broadcasts": lambda: (
64+
torch.randn(1, 4, 4, 1),
65+
torch.ones(1, 1, 4, 4),
66+
),
6367
}
6468

6569

backends/arm/test/ops/test_div.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@
6666
torch.rand(5, 10, 25, 20) + 1,
6767
None,
6868
),
69+
"op_div_rank4_randn_mutltiple_broadcasts": lambda: (
70+
torch.randn(1, 4, 4, 1),
71+
torch.randn(1, 1, 4, 4),
72+
None,
73+
),
6974
}
7075

7176

backends/arm/test/ops/test_mul.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
200 * torch.randn(1, 10, 25, 20),
5252
torch.rand(1, 10, 25, 1),
5353
),
54+
"op_mul_rank4_randn_mutltiple_broadcasts": lambda: (
55+
torch.randn(1, 4, 4, 1),
56+
torch.randn(1, 1, 4, 4),
57+
),
5458
}
5559

5660

backends/arm/test/ops/test_sub.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
"rand_3D_4x4x4": lambda: (torch.rand(4, 2, 2), torch.rand(4, 2, 2)),
3939
"rand_4D_2x2x4x4": lambda: (torch.rand(2, 2, 4, 4), torch.rand(2, 2, 4, 4)),
4040
"zeros": lambda: (torch.rand(4, 4), torch.zeros(4, 4)),
41+
"randn_4D_mutltiple_broadcasts": lambda: (
42+
torch.randn(1, 4, 4, 1),
43+
torch.randn(1, 1, 4, 4),
44+
),
4145
}
4246
fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"}
4347

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import operator
from typing import Callable, Tuple

import torch
from executorch.backends.arm._passes import BroadcastArgsPass

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
# The models in this file take two tensors, so the input type is a pair.
input_t = Tuple[torch.Tensor, torch.Tensor]  # Inputs x, y
16+
17+
18+
class NeedsMultipleBroadcastsModel(torch.nn.Module):
    """Apply a binary operator to two inputs that both need broadcasting."""

    # Shapes (1, 10) and (10, 1) broadcast to (10, 10), so neither input
    # already has the output shape.
    test_data = (torch.rand(1, 10), torch.rand(10, 1))

    def __init__(self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
        """Store the binary callable (e.g. ``operator.add``) used in forward."""
        # Initialise the Module machinery before assigning any attributes.
        super().__init__()
        self.op = op

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``op(x, y)``."""
        return self.op(x, y)
27+
28+
29+
# One model instance per elementwise operator, keyed by the test-case id.
modules = dict(
    add=NeedsMultipleBroadcastsModel(operator.add),
    sub=NeedsMultipleBroadcastsModel(operator.sub),
    mul=NeedsMultipleBroadcastsModel(operator.mul),
    div=NeedsMultipleBroadcastsModel(operator.truediv),
)
35+
36+
37+
@common.parametrize("module", modules)
def test_multiple_broacasts_model(module: NeedsMultipleBroadcastsModel):
    """Run BroadcastArgsPass on a model whose two inputs both need broadcasting.

    The pipeline is configured with the repeat op listed under
    ``ops_not_before_pass`` and with a count of 1 under ``ops_after_pass``.
    """
    repeat_op = "executorch_exir_dialects_edge__ops_aten_repeat_default"
    PassPipeline[input_t](
        module,
        module.test_data,
        quantize=True,
        ops_not_before_pass=[repeat_op],
        ops_after_pass={repeat_op: 1},
        pass_list=[BroadcastArgsPass],
    ).run()

0 commit comments

Comments
 (0)