Skip to content

Commit 6260921

Browse files
Arm backend: Add support for single input matmul (#10654)
### Summary AnnotateDecomposedMatmul makes sure that a decomposed matmul will two dq-nodes before and a q-node after it's mm/bmm-node. Previously it assumed that the partition always had two input nodes (two dq-nodes), but this is not the case for a single input matmul, e.g. torch.matmul(x, x). In such a case we must copy the dq-node and insert it before the mm/bmm's two inputs. ``` Before pass: -> expand -> view -> / \ x -> dq bmm -> view -> q \ / -> expand -> view -> After pass: -> expand -> view -> dq / \ x bmm -> q -> view \ / -> expand -> view -> dq ``` Signed-off-by: Oscar Andersson <[email protected]>
1 parent e88aafc commit 6260921

File tree

3 files changed

+206
-24
lines changed

3 files changed

+206
-24
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7070
if quantized_input:
7171
matmul_args = matmul_node.all_input_nodes
7272
for node in matmul_args:
73+
# Find the dq-node connected to this mm/bmm arg
7374
input_node = self._match_partition_to_node(
7475
node, partition.input_nodes
7576
)
76-
77-
# Remove partition input dq-node
78-
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
79-
graph_module.graph.erase_node(input_node)
8077
input_node_qargs = QuantArgs.from_operator(
8178
input_node.target, input_node.args
8279
)
83-
80+
# Insert new dq-node just before the mm/bmm with input_node's qparams
8481
with graph_module.graph.inserting_before(matmul_node):
8582
# Create new dq-node before matmul
8683
dq_node = create_node(
@@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
9087
dq_node.args = (node, *input_node_qargs)
9188
matmul_node.replace_input_with(node, dq_node)
9289

90+
for partition_input in partition.input_nodes:
91+
# Remove partition input dq-node
92+
partition_input.replace_all_uses_with(
93+
partition_input.all_input_nodes[0]
94+
)
95+
graph_module.graph.erase_node(partition_input)
96+
9397
partition_output = list(partition.output_nodes[0].users)[0]
9498
quantized_output = partition_output.target == q_op
9599
if quantized_output:

backends/arm/test/ops/test_bmm.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,6 @@ class BMM(torch.nn.Module):
3232
def forward(self, x, y):
3333
return torch.bmm(x, y)
3434

35-
class MatMul(torch.nn.Module):
36-
test_data_generators = [
37-
lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
38-
lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
39-
]
40-
41-
def forward(self, x, y):
42-
return torch.matmul(x, y)
43-
4435
class BMMSingleInput(torch.nn.Module):
4536
test_data_generators = [
4637
lambda: (torch.rand(20, 3, 3),),
@@ -129,16 +120,6 @@ def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]
129120
test_data = test_data_generator()
130121
self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)
131122

132-
@parameterized.expand(MatMul.test_data_generators)
133-
def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
134-
test_data = test_data_generator()
135-
self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)
136-
137-
@parameterized.expand(MatMul.test_data_generators)
138-
def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
139-
test_data = test_data_generator()
140-
self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)
141-
142123
@parameterized.expand(BMM.test_data_generators)
143124
def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
144125
test_data = test_data_generator()

backends/arm/test/ops/test_matmul.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineBI,
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
aten_op_mm = "torch.ops.aten.matmul.default"
18+
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"
19+
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
20+
21+
22+
class MatMul(torch.nn.Module):
23+
test_data_generators = {
24+
"rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
25+
"rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
26+
}
27+
28+
def forward(self, x: torch.Tensor, y: torch.Tensor):
29+
return torch.matmul(x, y)
30+
31+
32+
class MatMulSingleInput(torch.nn.Module):
33+
test_data_generators = {
34+
"rand_3d": lambda: (torch.rand(2, 5, 5),),
35+
"rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
36+
}
37+
38+
def forward(self, x: torch.Tensor):
39+
return torch.matmul(x, x)
40+
41+
42+
class MatMulCombo(torch.nn.Module):
43+
test_data_generators = {
44+
"rand_rand_rand_3d": lambda: (
45+
torch.rand(2, 5, 5),
46+
torch.rand(2, 5, 2),
47+
torch.rand(2, 2, 5),
48+
),
49+
"rand_rand_rand_4d": lambda: (
50+
torch.rand(1, 2, 5, 5),
51+
torch.rand(1, 2, 5, 2),
52+
torch.rand(1, 2, 2, 5),
53+
),
54+
}
55+
56+
def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor):
57+
y1 = torch.matmul(x1, x1)
58+
y2 = torch.matmul(x2, x3)
59+
return y1 + y2
60+
61+
62+
@common.parametrize("test_data", MatMul.test_data_generators)
63+
def test_matmul_tosa_MI(test_data: input_t1):
64+
pipeline = TosaPipelineMI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm)
65+
pipeline.run()
66+
67+
68+
@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
69+
def test_matmul_single_input_tosa_MI(test_data: input_t1):
70+
pipeline = TosaPipelineMI[input_t1](
71+
MatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm
72+
)
73+
pipeline.run()
74+
75+
76+
@common.parametrize("test_data", MatMulCombo.test_data_generators)
77+
def test_matmul_combo_tosa_MI(test_data: input_t1):
78+
pipeline = TosaPipelineMI[input_t1](
79+
MatMulCombo(), test_data(), aten_op_mm, exir_op_mm
80+
)
81+
pipeline.run()
82+
83+
84+
@common.parametrize("test_data", MatMul.test_data_generators)
85+
def test_matmul_tosa_BI(test_data: input_t1):
86+
pipeline = TosaPipelineBI[input_t1](
87+
MatMul(), test_data(), aten_op_mm, exir_op_mm, qtol=1
88+
)
89+
pipeline.run()
90+
91+
92+
@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
93+
def test_matmul_single_input_tosa_BI(test_data: input_t1):
94+
pipeline = TosaPipelineMI[input_t1](
95+
MatMulSingleInput(),
96+
test_data(),
97+
aten_op_mm,
98+
exir_op_mm,
99+
qtol=1,
100+
)
101+
pipeline.run()
102+
103+
104+
@common.parametrize("test_data", MatMulCombo.test_data_generators)
105+
def test_matmul_combo_tosa_BI(test_data: input_t1):
106+
pipeline = TosaPipelineBI[input_t1](
107+
MatMulCombo(),
108+
test_data(),
109+
aten_op_mm,
110+
exir_op_mm,
111+
qtol=1,
112+
)
113+
pipeline.run()
114+
115+
116+
@common.parametrize("test_data", MatMul.test_data_generators)
117+
@common.XfailIfNoCorstone300
118+
def test_matmul_u55_BI(test_data: input_t1):
119+
pipeline = EthosU55PipelineBI[input_t1](
120+
MatMul(),
121+
test_data(),
122+
aten_op_mm,
123+
exir_op_mm,
124+
run_on_fvp=True,
125+
use_to_edge_transform_and_lower=True,
126+
)
127+
pipeline.run()
128+
129+
130+
@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
131+
@common.XfailIfNoCorstone300
132+
def test_matmul_single_input_u55_BI(test_data: input_t1):
133+
pipeline = EthosU55PipelineBI[input_t1](
134+
MatMulSingleInput(),
135+
test_data(),
136+
aten_op_mm,
137+
exir_op_mm,
138+
run_on_fvp=True,
139+
use_to_edge_transform_and_lower=True,
140+
)
141+
pipeline.run()
142+
143+
144+
@common.parametrize("test_data", MatMulCombo.test_data_generators)
145+
@common.XfailIfNoCorstone300
146+
def test_matmul_combo_u55_BI(test_data: input_t1):
147+
pipeline = EthosU55PipelineBI[input_t1](
148+
MatMulCombo(),
149+
test_data(),
150+
aten_op_mm,
151+
exir_op_mm,
152+
run_on_fvp=True,
153+
use_to_edge_transform_and_lower=True,
154+
)
155+
pipeline.run()
156+
157+
158+
@common.parametrize("test_data", MatMul.test_data_generators)
159+
@common.XfailIfNoCorstone320
160+
def test_matmul_u85_BI(test_data: input_t1):
161+
pipeline = EthosU85PipelineBI[input_t1](
162+
MatMul(),
163+
test_data(),
164+
aten_op_mm,
165+
exir_op_mm,
166+
run_on_fvp=True,
167+
use_to_edge_transform_and_lower=True,
168+
)
169+
pipeline.run()
170+
171+
172+
@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
173+
@common.XfailIfNoCorstone320
174+
def test_matmul_single_input_u85_BI(test_data: input_t1):
175+
pipeline = EthosU85PipelineBI[input_t1](
176+
MatMulSingleInput(),
177+
test_data(),
178+
aten_op_mm,
179+
exir_op_mm,
180+
run_on_fvp=True,
181+
use_to_edge_transform_and_lower=True,
182+
)
183+
pipeline.run()
184+
185+
186+
@common.parametrize("test_data", MatMulCombo.test_data_generators)
187+
@common.XfailIfNoCorstone320
188+
def test_matmul_combo_u85_BI(test_data: input_t1):
189+
pipeline = EthosU85PipelineBI[input_t1](
190+
MatMulCombo(),
191+
test_data(),
192+
aten_op_mm,
193+
exir_op_mm,
194+
run_on_fvp=True,
195+
use_to_edge_transform_and_lower=True,
196+
)
197+
pipeline.run()

0 commit comments

Comments
 (0)