Skip to content

Commit 05abf89

Browse files
Arm backend: Add MM to BMM pass (#7341)
aten.mm does not support input or output tensors of rank 3, which is required by TOSA for the MM operation. Therefore, create a pass that converts any MM nodes to BMM (which always has rank 3). The pass also unsqueezes tensors of rank 2 to rank 3. As a result of the new pass, op_mm.py is no longer required and has been removed. Change-Id: I8459dd73bb366452b5139b48a5724c300b2d5a26 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent b16271c commit 05abf89

File tree

5 files changed

+103
-156
lines changed

5 files changed

+103
-156
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -36,7 +36,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
3636
itertools.chain.from_iterable(matmul_partitions.values())
3737
)
3838
matmul_targets = {
39-
exir_ops.edge.aten.mm.default,
4039
exir_ops.edge.aten.bmm.default,
4140
}
4241
for partition in matmul_partitions:

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
@@ -45,6 +45,7 @@
4545
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
4646
ConvertMeanDimToAveragePool,
4747
)
48+
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
4849
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
4950
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
5051
ScalarsToAttributePass,
@@ -79,6 +80,7 @@ def transform_to_backend_pipeline(
7980
self.add_pass(ConvertMeanDimToAveragePool())
8081
self.add_pass(DecomposeMeanDimPass())
8182
self.add_pass(ConvertSplitToSlicePass())
83+
self.add_pass(ConvertMmToBmmPass())
8284
# TODO MLETORCH-558
8385
self.add_pass(AnnotateDecomposedMatmulPass())
8486
self.add_pass(QuantizeFullArgument())
@@ -99,7 +101,6 @@ def transform_to_backend_pipeline(
99101
exir_ops.edge.aten.hardtanh.default,
100102
exir_ops.edge.aten.log.default,
101103
exir_ops.edge.aten.max_pool2d.default,
102-
exir_ops.edge.aten.mm.default,
103104
exir_ops.edge.aten.mul.Tensor,
104105
exir_ops.edge.aten.permute_copy.default,
105106
exir_ops.edge.aten.reciprocal.default,
backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.backends.arm._passes.arm_pass_utils import (
9+
create_node,
10+
get_first_fake_tensor,
11+
insert_q_dq_pair,
12+
)
13+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
from torch.fx import Node
17+
18+
19+
class ConvertMmToBmmPass(ExportPass):
    """
    This pass converts a MM node to a BMM one and turns input and output tensors
    from rank 2 to rank 3. The TOSA specification requires rank 3. The graph is
    modified to do the following:
    1) Unsqueeze input tensors to rank 3.
    2) Convert MM node to BMM.
    3) Squeeze output tensor to rank 2.

    When an input is produced by a dequantize op, or every consumer of the
    result is a quantize op, an extra q/dq pair is inserted next to the new
    unsqueeze/bmm node so that the inserted nodes stay inside the quantized
    region of the graph.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite all edge ``aten.mm`` nodes of ``graph_module`` into ``aten.bmm``.

        Args:
            graph_module: The graph module to transform; its graph is modified
                in place.

        Returns:
            A PassResult holding the (possibly re-traced) graph module and a
            flag telling whether any MM node was rewritten.

        Raises:
            RuntimeError: If an input of a MM node does not have rank 2.
        """
        modified_graph = False
        graph = graph_module.graph
        node_list = graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mm.default
        )
        for node in node_list:
            # Unsqueeze input tensors to rank 3
            for input_node in node.args:
                if not isinstance(input_node, Node):
                    continue

                shape = get_first_fake_tensor(input_node).shape
                rank = len(shape)
                if rank != 2:
                    raise RuntimeError(f"Input tensor has rank {rank}, must be 2")

                with graph.inserting_before(node):
                    unsqueeze_before = create_node(
                        graph, exir_ops.edge.aten.unsqueeze_copy.default
                    )
                    unsqueeze_before.args = (
                        input_node,  # Input is node's original input
                        0,  # New leading dimension
                    )
                    node.replace_input_with(input_node, unsqueeze_before)

                # If quantized we must insert unsqueeze --> q --> dq --> node,
                # reusing the parameters of the dq op that fed the MM node.
                if input_node.target == dq_op:
                    q_params = input_node.args[1:]
                    insert_q_dq_pair(graph, unsqueeze_before, q_params)

            # Replace mm node with bmm. The args were already redirected to
            # the unsqueezed inputs above, so they can be reused as-is.
            with graph.inserting_before(node):
                bmm_node = create_node(
                    graph,
                    exir_ops.edge.aten.bmm.default,
                )
                bmm_node.args = node.args
                node.replace_all_uses_with(bmm_node)
                graph.erase_node(node)

            # Squeeze output tensor back to rank 2
            with graph.inserting_after(bmm_node):
                squeeze_after = create_node(
                    graph,
                    exir_ops.edge.aten.squeeze_copy.dims,
                )
                squeeze_after.args = (
                    bmm_node,
                    [0],
                )
                # Snapshot the users first: redirecting them mutates
                # bmm_node.users while we would otherwise be iterating it.
                original_users = [
                    user for user in bmm_node.users if user != squeeze_after
                ]
                for user in original_users:
                    user.replace_input_with(bmm_node, squeeze_after)

            # If quantized, insert bmm --> q --> dq --> squeeze.
            # all() is True for an empty sequence, so explicitly require at
            # least one user before reading the parameters of the first one.
            # NOTE(review): presumably all q users share the same quantization
            # parameters; only the first user's are used -- confirm.
            if original_users and all(
                original_user.target == q_op for original_user in original_users
            ):
                q_params = original_users[0].args[1:]
                insert_q_dq_pair(graph, bmm_node, q_params)

            modified_graph = True

        if modified_graph:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified_graph)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,7 +22,6 @@
2222
op_max,
2323
op_max_pool2d,
2424
op_min,
25-
op_mm,
2625
op_mul,
2726
op_permute,
2827
op_quant,

backends/arm/operators/op_mm.py

Lines changed: 0 additions & 150 deletions
This file was deleted.

0 commit comments

Comments
 (0)