Skip to content

Commit b042484

Browse files
Merge branch 'main' into add-int32-support-to-where-op
2 parents 2bbc032 + 994752e commit b042484

File tree

4 files changed

+120
-15
lines changed

4 files changed

+120
-15
lines changed

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ def fold_and_annotate_arg(
142142
f"Expected one of {dq_ops} dq_op, got {n.target}"
143143
)
144144

145-
if len(n.args) > 0:
146-
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
147-
graph_module.graph.erase_node(n)
145+
node.replace_input_with(n, cast(Node, n.args[0]))
146+
if len(n.users) == 0:
147+
graph_module.graph.erase_node(n)
148148

149149
def call(self, graph_module: GraphModule) -> PassResult:
150150

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,21 +116,29 @@ def call(self, graph_module):
116116
or torch._export.utils.is_buffer(self.exported_program, input_node)
117117
for input_node in input_nodes
118118
)
119-
input_nodes_single_users = (
120-
len(input_node.users) == 1 for input_node in input_nodes
121-
)
119+
if not all(input_nodes_constant):
120+
continue
122121

123-
if all(input_nodes_constant) and all(input_nodes_single_users):
124-
try:
125-
did_fuse = self._fuse_nodes(node)
122+
try:
123+
did_fuse = self._fuse_nodes(node)
124+
if did_fuse:
125+
logger.debug(
126+
f"Fused constant op: {node.name} with placeholder inputs:"
127+
f"{[input_node.name for input_node in input_nodes]}"
128+
)
126129
modified |= did_fuse
127-
if did_fuse:
128-
graph_module.recompile() # Recompile needed to catch chains of constant ops
129-
input_nodes_to_delete.extend(input_nodes)
130-
except Exception as e:
131-
logger.warning(
132-
f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
130+
graph_module.recompile() # Recompile needed to catch chains of constant ops
131+
input_nodes_to_delete.extend(
132+
[
133+
input_node
134+
for input_node in input_nodes
135+
if len(input_node.users) == 1
136+
]
133137
)
138+
except Exception as e:
139+
logger.warning(
140+
f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
141+
)
134142

135143
if modified:
136144
graph_module.graph.eliminate_dead_code()

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"hardswish.default",
1414
"linear.default",
1515
"maximum.default",
16+
"multihead_attention.default",
1617
"adaptive_avg_pool2d.default",
1718
"bitwise_right_shift.Tensor",
1819
"bitwise_left_shift.Tensor",
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import pytest
7+
import torch
8+
from executorch.backends.arm.test import common
9+
from executorch.backends.arm.test.tester.test_pipeline import (
10+
EthosU55PipelineBI,
11+
EthosU85PipelineBI,
12+
TosaPipelineBI,
13+
TosaPipelineMI,
14+
)
15+
16+
17+
class MultiheadAttention(torch.nn.MultiheadAttention):
    """Thin test wrapper over ``torch.nn.MultiheadAttention``.

    The override simply delegates to the parent implementation; having a
    concrete ``forward`` on the subclass keeps the wrapped module usable
    as a plain test module.
    # NOTE(review): presumably needed so export/tracing picks up forward
    # on this subclass — confirm against the test pipeline.
    """

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


# (inputs, module) pair produced by each test-suite factory below.
input_t1 = tuple[torch.Tensor, torch.nn.Module]

# Each value is a zero-argument factory so that fresh random tensors and a
# fresh module are created for every test invocation.
# test_name, (x,), embed_dim, num_heads, batch_first
test_suite = {
    "rand_2d": lambda: (
        (torch.rand(6, 3),),
        MultiheadAttention(embed_dim=3, num_heads=3, batch_first=True),
    ),
    "randn_2d": lambda: (
        (torch.randn(2, 4),),
        MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True),
    ),
    "randn_3d": lambda: (
        (torch.randn(3, 2, 4),),
        MultiheadAttention(embed_dim=4, num_heads=2, batch_first=False),
    ),
}
38+
39+
40+
@common.parametrize("test_data", test_suite)
def test_multihead_attention_tosa_MI(test_data: input_t1):
    """Run the TOSA MI pipeline on multihead attention.

    ``test_data`` is a factory yielding (inputs, module); the single input
    tuple is tripled to supply query, key and value.
    """
    inputs, model = test_data()
    qkv = (*inputs, *inputs, *inputs)
    TosaPipelineMI(model, qkv, [], []).run()
48+
49+
50+
@common.parametrize("test_data", test_suite)
def test_multihead_attention_tosa_BI(test_data: input_t1):
    """Run the TOSA BI (quantized) pipeline on multihead attention.

    Fix: annotate ``test_data`` with ``input_t1`` for consistency with the
    sibling MI/U55/U85 tests, which all carry this annotation.
    # NOTE(review): as in the siblings, the parametrized value is actually a
    # zero-arg factory returning an ``input_t1`` pair — annotation kept loose
    # to match the file's convention.
    """
    test_data, module = test_data()
    # Triple the inputs so the same tensor serves as query, key and value.
    pipeline = TosaPipelineBI(module, (*test_data, *test_data, *test_data), [], [])
    pipeline.run()
58+
59+
60+
@common.parametrize("test_data", test_suite)
@pytest.mark.xfail(reason="MLETORCH-1102: Numerical issues on FVP")
@common.XfailIfNoCorstone300
def test_multihead_attention_u55_BI(test_data: input_t1):
    """Run the Ethos-U55 BI pipeline on FVP; currently xfailed.

    The exir op-count check stage is removed before running.
    """
    inputs, model = test_data()
    qkv = (*inputs, *inputs, *inputs)
    pipeline = EthosU55PipelineBI(
        model,
        qkv,
        [],
        [],
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )
    # Skip the exir operator-count check for this op.
    pipeline.pop_stage("check_count.exir")
    pipeline.run()
78+
79+
80+
@common.parametrize("test_data", test_suite)
@pytest.mark.xfail(reason="MLETORCH-1102: Numerical issues on FVP")
@common.XfailIfNoCorstone320
def test_multihead_attention_u85_BI(test_data: input_t1):
    """Run the Ethos-U85 BI pipeline on FVP; currently xfailed."""
    inputs, model = test_data()
    qkv = (*inputs, *inputs, *inputs)
    EthosU85PipelineBI(
        model,
        qkv,
        [],
        [],
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    ).run()

0 commit comments

Comments
 (0)