Skip to content

Commit 720cd76

Browse files
authored
Arm backend: add support for operator @ (#10749)
### Summary Support @ operator, an equivalent op to torch.matmul.
1 parent 8ee9f91 commit 720cd76

File tree

3 files changed

+154
-4
lines changed

3 files changed

+154
-4
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import itertools
10-
9+
import operator
1110
from typing import List
1211

1312
import torch
@@ -22,7 +21,7 @@
2221

2322
class AnnotateDecomposedMatmulPass(ExportPass):
2423
"""
25-
torch.matmul can be decomposed in many ways, for instance:
24+
torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance:
2625
dq -> matmul -> q can become
2726
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
2827
difficult. This helper function find all matmul partitions and annotate its
@@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
5049
graph_module.graph,
5150
[
5251
torch.matmul,
52+
operator.matmul,
5353
],
5454
None,
5555
)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def _is_matmul_node_supported(
335335
graph_module.graph,
336336
[
337337
torch.matmul,
338+
operator.matmul,
338339
],
339340
None,
340341
)
@@ -385,7 +386,7 @@ def is_node_supported(
385386
):
386387
source_fn_stack: tuple[typing.Any] = node.meta.get("source_fn_stack", [])
387388
if len(source_fn_stack) > 0:
388-
if source_fn_stack[-1][1] in (torch.matmul,):
389+
if source_fn_stack[-1][1] in (torch.matmul, operator.matmul):
389390
return self._is_matmul_node_supported(submodules, node)
390391

391392
elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):

backends/arm/test/ops/test_at.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
TosaPipelineBI,
12+
TosaPipelineMI,
13+
)
14+
15+
aten_op_mm = "torch.ops.aten.matmul.default"
16+
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"
17+
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
18+
19+
20+
class AtMatMulSingleInput(torch.nn.Module):
21+
test_data_generators = {
22+
"rand_3d": lambda: (torch.rand(2, 5, 5),),
23+
"rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
24+
}
25+
26+
def forward(self, x: torch.Tensor):
27+
return x @ x
28+
29+
30+
class AtMatMulDoubleInput(torch.nn.Module):
31+
test_data_generators = {
32+
"rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
33+
"rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
34+
}
35+
36+
def forward(self, x: torch.Tensor, y: torch.Tensor):
37+
return x @ y
38+
39+
40+
class AtMatMulMixedPattern1(torch.nn.Module):
41+
test_data_generators = {
42+
"rand_rand_rand_3d": lambda: (
43+
torch.rand(2, 5, 5),
44+
torch.rand(2, 5, 2),
45+
torch.rand(2, 2, 5),
46+
),
47+
"rand_rand_rand_4d": lambda: (
48+
torch.rand(1, 2, 5, 5),
49+
torch.rand(1, 2, 5, 2),
50+
torch.rand(1, 2, 2, 5),
51+
),
52+
}
53+
54+
def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor):
55+
y1 = torch.matmul(x1, x1)
56+
y2 = torch.matmul(x2, x3)
57+
return y1 + y2
58+
59+
60+
class AtMatMulMixedPattern2(torch.nn.Module):
61+
test_data_generators = {
62+
"rand_rand_rand_3d": lambda: (
63+
torch.rand(2, 5, 5),
64+
torch.rand(2, 5, 2),
65+
torch.rand(2, 2, 5),
66+
),
67+
"rand_rand_rand_4d": lambda: (
68+
torch.rand(1, 2, 5, 5),
69+
torch.rand(1, 2, 5, 2),
70+
torch.rand(1, 2, 2, 5),
71+
),
72+
}
73+
74+
def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor):
75+
y1 = torch.matmul(x1, x1)
76+
y2 = torch.matmul(x2, x3)
77+
return y1 @ y2
78+
79+
80+
@common.parametrize("test_data", AtMatMulSingleInput.test_data_generators)
81+
def test_atmatmul_single_input_tosa_MI(test_data: input_t1):
82+
pipeline = TosaPipelineMI[input_t1](
83+
AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm
84+
)
85+
pipeline.run()
86+
87+
88+
@common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators)
89+
def test_atmatmul_double_input_tosa_MI(test_data: input_t1):
90+
pipeline = TosaPipelineMI[input_t1](
91+
AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm
92+
)
93+
pipeline.run()
94+
95+
96+
@common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators)
97+
def test_atmatmul_mixed_pattern1_tosa_MI(test_data: input_t1):
98+
pipeline = TosaPipelineMI[input_t1](
99+
AtMatMulMixedPattern1(), test_data(), aten_op_mm, exir_op_mm
100+
)
101+
pipeline.run()
102+
103+
104+
@common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators)
105+
def test_atmatmul_mixed_pattern2_tosa_MI(test_data: input_t1):
106+
pipeline = TosaPipelineMI[input_t1](
107+
AtMatMulMixedPattern2(), test_data(), aten_op_mm, exir_op_mm
108+
)
109+
pipeline.run()
110+
111+
112+
@common.parametrize("test_data", AtMatMulSingleInput.test_data_generators)
113+
def test_atmatmul_single_input_tosa_BI(test_data: input_t1):
114+
pipeline = TosaPipelineBI[input_t1](
115+
AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm
116+
)
117+
pipeline.run()
118+
119+
120+
@common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators)
121+
def test_atmatmul_double_input_tosa_BI(test_data: input_t1):
122+
pipeline = TosaPipelineBI[input_t1](
123+
AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm
124+
)
125+
pipeline.run()
126+
127+
128+
@common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators)
129+
def test_atmatmul_mixed_pattern1_tosa_BI(test_data: input_t1):
130+
pipeline = TosaPipelineBI[input_t1](
131+
AtMatMulMixedPattern1(),
132+
test_data(),
133+
aten_op_mm,
134+
exir_op_mm,
135+
qtol=1,
136+
)
137+
pipeline.run()
138+
139+
140+
@common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators)
141+
def test_atmatmul_mixed_pattern2_tosa_BI(test_data: input_t1):
142+
pipeline = TosaPipelineBI[input_t1](
143+
AtMatMulMixedPattern2(),
144+
test_data(),
145+
aten_op_mm,
146+
exir_op_mm,
147+
qtol=1,
148+
)
149+
pipeline.run()

0 commit comments

Comments
 (0)