Skip to content

Commit 66b2f73

Browse files
authored
Implement bmm op for Arm backend
Differential Revision: D61852906 Pull Request resolved: #4926
1 parent eeb52d5 commit 66b2f73

File tree

5 files changed

+220
-1
lines changed

5 files changed

+220
-1
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4040
exir_ops.edge.aten.addmm.default,
4141
exir_ops.edge.aten.expand_copy.default,
4242
exir_ops.edge.aten.cat.default,
43+
exir_ops.edge.aten.bmm.default,
4344
exir_ops.edge.aten.permute_copy.default,
4445
exir_ops.edge.aten.hardtanh.default,
4546
exir_ops.edge.aten.convolution.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
op_addmm,
1010
op_avg_pool2d,
1111
op_batch_norm,
12+
op_bmm,
1213
op_cat,
1314
op_conv2d,
1415
op_dequant,

backends/arm/operators/op_bmm.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import List
7+
8+
import serializer.tosa_serializer as ts
9+
import torch.fx
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
16+
from executorch.backends.arm.tosa_utils import get_two_inputs
17+
from serializer.tosa_serializer import TosaOp
18+
19+
20+
@register_node_visitor
class BMMVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.bmm.default`` to a TOSA ``MATMUL``.

    For non-quantized nodes the MATMUL writes straight into the node's
    output tensor. For quantized nodes the MATMUL writes into an INT32
    intermediate tensor, which is then rescaled to the INT8 output.
    """

    target = "aten.bmm.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the operators implementing ``node`` to ``tosa_graph``.

        Args:
            node: The ``bmm`` FX node being lowered.
            tosa_graph: Serializer the MATMUL (and, when quantized, the
                rescale) is added to.
            inputs: Unused here; the two operands are obtained from
                ``node`` via ``get_two_inputs``.
            output: Destination tensor of the node.
            is_quant_node: When true, take the quantized path (zero points
                on the MATMUL, INT32 intermediate, trailing rescale).
        """
        input0, input1 = get_two_inputs(node)

        # aten.bmm maps directly to MATMUL
        # NOTE: For now, only INT8 & FP32 are supported

        if is_quant_node:
            # Look up the quantization parameters of both operands once; the
            # zero points feed the MATMUL attribute and the scales feed the
            # rescale below.
            input0_q_params = get_quant_node_args(input0)
            input1_q_params = get_quant_node_args(input1)
            input0_zp = input0_q_params.zp
            input1_zp = input1_q_params.zp
            # The MATMUL result goes to an INT32 intermediate tensor that is
            # rescaled afterwards.
            bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
            bmm_output_name = bmm_result.name
        else:
            input0_zp, input1_zp = 0, 0
            bmm_output_name = output.name

        # Add the MATMUL to the TOSA graph.
        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)

        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [input0.name, input1.name],
            [bmm_output_name],
            attr,
        )

        # As INT8 accumulates into INT32, we need to rescale it back to INT8
        if is_quant_node:
            # The output quantization parameters are read from the first user
            # of this node.
            output_q_params = get_quant_node_args(list(node.users)[0])

            final_output_scale = (
                input0_q_params.scale * input1_q_params.scale
            ) / output_q_params.scale

            build_rescale(
                tosa_fb=tosa_graph,
                scale=final_output_scale,
                input_node=bmm_result,
                output_name=output.name,
                output_type=ts.DType.INT8,
                output_shape=bmm_result.shape,
                input_zp=0,
                output_zp=output_q_params.zp,
                is_double_round=False,
            )

backends/arm/quantizer/quantization_annotation/mm_annotator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _annotate_mm(
2222
quantization_config: QuantizationConfig,
2323
filter_fn: Optional[Callable[[Node], bool]] = None,
2424
) -> Optional[List[List[Node]]]:
25-
mm_partitions = get_source_partitions(gm.graph, [torch.mm], filter_fn)
25+
mm_partitions = get_source_partitions(gm.graph, [torch.mm, torch.bmm], filter_fn)
2626
mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values()))
2727
annotated_partitions = []
2828
for mm_partition in mm_partitions:

backends/arm/test/ops/test_bmm.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from typing import Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14+
from parameterized import parameterized
15+
16+
# Seed torch's RNG so the random tensors built below are the same on every run.
torch.manual_seed(1)
17+
18+
19+
class TestBMM(unittest.TestCase):
20+
"""Tests Batch MatMul"""
21+
22+
class BMM(torch.nn.Module):
23+
test_parameters = [
24+
(torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
25+
(torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
26+
(torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
27+
(10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
28+
(-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
29+
]
30+
31+
def forward(self, x, y):
32+
return torch.bmm(x, y)
33+
34+
class BMMSingleInput(torch.nn.Module):
35+
test_parameters = [
36+
(torch.rand(20, 3, 3),),
37+
(torch.ones(2, 128, 128),),
38+
(10000 * torch.randn(4, 25, 25),),
39+
(5 + 5 * torch.randn(3, 64, 64),),
40+
]
41+
42+
def forward(self, x):
43+
return torch.bmm(x, x)
44+
45+
def _test_bmm_tosa_MI_pipeline(
46+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
47+
):
48+
(
49+
ArmTester(
50+
module,
51+
example_inputs=test_data,
52+
compile_spec=common.get_tosa_compile_spec(),
53+
)
54+
.export()
55+
.check_count({"torch.ops.aten.bmm.default": 1})
56+
.check_not(["torch.ops.quantized_decomposed"])
57+
.to_edge()
58+
.partition()
59+
.check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
60+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
61+
.to_executorch()
62+
.run_method_and_compare_outputs(inputs=test_data)
63+
)
64+
65+
def _test_bmm_tosa_BI_pipeline(
66+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
67+
):
68+
(
69+
ArmTester(
70+
module,
71+
example_inputs=test_data,
72+
compile_spec=common.get_tosa_compile_spec(),
73+
)
74+
.quantize()
75+
.export()
76+
.check_count({"torch.ops.aten.bmm.default": 1})
77+
.check(["torch.ops.quantized_decomposed"])
78+
.to_edge()
79+
.partition()
80+
.check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
81+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
82+
.to_executorch()
83+
.run_method_and_compare_outputs(inputs=test_data)
84+
)
85+
86+
def _test_bmm_u55_BI_pipeline(
87+
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
88+
):
89+
(
90+
ArmTester(
91+
module,
92+
example_inputs=test_data,
93+
compile_spec=common.get_u55_compile_spec(),
94+
)
95+
.quantize()
96+
.export()
97+
.check_count({"torch.ops.aten.bmm.default": 1})
98+
.check(["torch.ops.quantized_decomposed"])
99+
.to_edge()
100+
.partition()
101+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
102+
.to_executorch()
103+
)
104+
105+
@parameterized.expand(BMM.test_parameters)
106+
def test_bmm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
107+
test_data = (operand1, operand2)
108+
self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)
109+
110+
@parameterized.expand(BMMSingleInput.test_parameters)
111+
def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor):
112+
test_data = (operand1,)
113+
self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)
114+
115+
@parameterized.expand(BMM.test_parameters)
116+
def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
117+
test_data = (operand1, operand2)
118+
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
119+
120+
@parameterized.expand(BMMSingleInput.test_parameters)
121+
def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
122+
test_data = (operand1,)
123+
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)
124+
125+
@parameterized.expand(BMM.test_parameters)
126+
def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
127+
test_data = (operand1, operand2)
128+
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
129+
130+
# Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
131+
@parameterized.expand(BMMSingleInput.test_parameters)
132+
@unittest.expectedFailure
133+
def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
134+
test_data = (operand1,)
135+
self._test_bmm_u55_BI_pipeline(self.BMMSingleInput(), test_data)

0 commit comments

Comments
 (0)