
Commit 48b4304

Implement mm op for Arm backend
Differential Revision: D61240788
Pull Request resolved: #4628
1 parent 938748b commit 48b4304

8 files changed: +330 -1 lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.sigmoid.default,
+           exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.repeat.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
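With mm.default on the supported-ops list, exported torch.mm graphs can now be delegated to the Arm backend. A minimal sketch of how the op surfaces in the edge dialect that is_node_supported inspects, using the standard torch.export / executorch.exir entry points (the partition/delegation step itself is omitted here):

    # Sketch: an exported torch.mm appears in the edge dialect as
    # aten.mm.default, which is_node_supported now accepts.
    import torch

    from executorch.exir import to_edge


    class MM(torch.nn.Module):
        def forward(self, x, y):
            return torch.mm(x, y)


    ep = torch.export.export(MM(), (torch.rand(3, 5), torch.rand(5, 2)))
    edge = to_edge(ep)
    print(edge.exported_program().graph)  # contains an aten.mm edge op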

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
    op_get_item,
    op_hardtanh,
    op_mean_dim,
+   op_mm,
    op_permute,
    op_quant,
    op_repeat,

backends/arm/operators/op_mm.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_utils import (
    build_reshape,
    expand_dims,
    get_two_inputs,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class MMVisitor(NodeVisitor):
    target = "aten.mm.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        input0, input1 = get_two_inputs(node)

        # aten.mm takes two rank-2 inputs, but TOSA MATMUL expects rank 3,
        # so both inputs are reshaped from (H, W) to (1, H, W).
        # NOTE: For now, only INT8 and FP32 are supported.
        reshape_dtype = ts.DType.INT8 if is_quant_node else ts.DType.FP32
        input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0)
        input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0)

        # The output also needs to be rank 3
        output_new_shape = (1, output.shape[0], output.shape[1])

        # For INT8 we need the zero points; otherwise they are 0
        input0_zp, input1_zp = 0, 0
        if is_quant_node:
            input0_zp = get_quant_node_args(input0).zp
            input1_zp = get_quant_node_args(input1).zp

        mat_mul_result = tosa_graph.addIntermediate(
            output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
        )

        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)

        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [input0_reshaped.name, input1_reshaped.name],
            [mat_mul_result.name],
            attr,
        )

        if is_quant_node:
            reshape_intermediate = tosa_graph.addIntermediate(
                output.shape, ts.DType.INT32
            )
            reshape_output_name = reshape_intermediate.name
        else:
            reshape_output_name = output.name

        # Reshape the final output back to rank 2
        build_reshape(
            tosa_graph, mat_mul_result.name, output.shape, reshape_output_name
        )

        # As INT8 accumulates into INT32, we need to rescale it back to INT8
        if is_quant_node:
            input0_q_params = get_quant_node_args(input0)
            input1_q_params = get_quant_node_args(input1)
            output_q_params = get_quant_node_args(list(node.users)[0])

            final_output_scale = (
                input0_q_params.scale * input1_q_params.scale
            ) / output_q_params.scale

            # As the input will be INT32, the input_zp must be set to 0
            build_rescale(
                tosa_fb=tosa_graph,
                scale=final_output_scale,
                input_node=reshape_intermediate,
                output_name=output.name,
                output_type=ts.DType.INT8,
                output_shape=reshape_intermediate.shape,
                input_zp=0,
                output_zp=output_q_params.zp,
                is_double_round=False,
            )
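For reference, the quantized path above accumulates the MATMUL in INT32 and then applies a single rescale using final_output_scale. A minimal NumPy sketch of that arithmetic, with illustrative scales and zero points (the values are assumptions, and the real TOSA RESCALE uses fixed-point multiplier/shift pairs rather than float math):

    # Minimal sketch of the INT8 mm requantization; parameters are illustrative.
    import numpy as np

    s0, zp0 = 0.02, 3       # assumed scale/zero point of input 0
    s1, zp1 = 0.05, -1      # assumed scale/zero point of input 1
    s_out, zp_out = 0.1, 5  # assumed scale/zero point of the mm output

    a_q = np.random.randint(-128, 128, size=(3, 5), dtype=np.int32)
    b_q = np.random.randint(-128, 128, size=(5, 2), dtype=np.int32)

    # MATMUL with A_zp/B_zp: zero points are subtracted before multiplication,
    # and the products accumulate into INT32.
    acc = (a_q - zp0) @ (b_q - zp1)

    # RESCALE back to INT8: scale by (s0 * s1) / s_out and add the output zero
    # point, matching final_output_scale in op_mm.py.
    out_q = np.clip(np.round(acc * (s0 * s1) / s_out) + zp_out, -128, 127).astype(np.int8)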

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -265,6 +265,7 @@ class ArmQuantizer(Quantizer):
        "sub",
        "mul",
        "sigmoid",
+       "mm",
    ]

    def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def decorator(annotator: AnnotatorType):
    conv_annotator,
    linear_annotator,
    max_pool2d_annotator,
+   mm_annotator,
    mul_annotator,
    sigmoid_annotator,
    sub_annotator,
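The mm_annotator module added below registers itself under the name "mm" via @register_annotator. A generic sketch of the registry pattern that decorator follows (illustrative only; the actual implementation lives in this __init__.py and may differ in detail):

    # Illustrative registry sketch; not the actual implementation.
    from typing import Callable, Dict

    AnnotatorType = Callable  # stand-in for the real annotator signature

    OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}


    def register_annotator(op: str):
        def decorator(annotator: AnnotatorType):
            OP_TO_ANNOTATOR[op] = annotator  # e.g. "mm" -> _annotate_mm
            return annotator

        return decorator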
backends/arm/quantizer/quantization_annotation/mm_annotator.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("mm")
def _annotate_mm(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    mm_partitions = get_source_partitions(gm.graph, [torch.mm], filter_fn)
    mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values()))
    annotated_partitions = []
    for mm_partition in mm_partitions:
        annotated_partitions.append(mm_partition.nodes)
        mm_node = mm_partition.output_nodes[0]

        if arm_quantizer_utils.is_annotated(mm_node):
            continue

        input_act_qspec = quantization_config.get_input_act_qspec()
        output_act_qspec = quantization_config.get_output_act_qspec()

        input_qspec_map = {}
        input_act0 = mm_node.args[0]
        if isinstance(input_act0, Node):
            if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mm_node.args[1]
        if isinstance(input_act1, Node):
            if not arm_quantizer_utils.is_input_ok_for_quantization(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mm_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions
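A hedged sketch of driving this annotator through the standard PT2E flow; the .quantize() stage of ArmTester in the tests below wraps equivalent steps. The capture helper and get_symmetric_quantization_config are assumptions about the surrounding APIs at this commit and may vary by PyTorch/ExecuTorch version:

    # Sketch only: exercises the "mm" annotator via prepare_pt2e/convert_pt2e.
    # capture_pre_autograd_graph and get_symmetric_quantization_config are
    # assumed from the surrounding codebase, not shown in this diff.
    import torch

    from executorch.backends.arm.quantizer.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


    class MM(torch.nn.Module):
        def forward(self, x, y):
            return torch.mm(x, y)


    example_inputs = (torch.rand(3, 5), torch.rand(5, 2))
    gm = capture_pre_autograd_graph(MM(), example_inputs)

    quantizer = ArmQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    prepared = prepare_pt2e(gm, quantizer)  # observers placed per the "mm" annotation
    prepared(*example_inputs)               # calibration pass
    quantized = convert_pt2e(prepared)      # materializes quantize/dequantize ops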

backends/arm/test/ops/test_mm.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TestMM(unittest.TestCase):
    """Tests MatMul"""

    class MM(torch.nn.Module):
        test_parameters = [
            (torch.rand(3, 5), torch.rand(5, 2)),
            (torch.rand(1, 1), torch.rand(1, 1)),
            (torch.ones(55, 3), torch.ones(3, 44)),
            (10000 * torch.randn(1, 10), torch.randn(10, 5)),
            (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
        ]

        def forward(self, x, y):
            return torch.mm(x, y)

    class MMSingleInput(torch.nn.Module):
        test_parameters = [
            (torch.rand(3, 3),),
            (torch.ones(128, 128),),
            (10000 * torch.randn(25, 25),),
            (5 + 5 * torch.randn(64, 64),),
        ]

        def forward(self, x):
            return torch.mm(x, x)

    def _test_mm_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mm_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mm_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(MM.test_parameters)
    def test_mm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_tosa_MI_pipeline(self.MM(), test_data)

    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_tosa_MI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)

    @parameterized.expand(MM.test_parameters)
    def test_mm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_tosa_BI_pipeline(self.MM(), test_data)

    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_tosa_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)

    @parameterized.expand(MM.test_parameters)
    def test_mm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_u55_BI_pipeline(self.MM(), test_data)

    # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale
    # previous operator is of type: Memcpy
    @parameterized.expand(MMSingleInput.test_parameters)
    @unittest.expectedFailure
    def test_mm_single_input_u55_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_u55_BI_pipeline(self.MMSingleInput(), test_data)

backends/arm/tosa_utils.py

Lines changed: 27 additions & 1 deletion
@@ -5,7 +5,7 @@

 import logging
 import os
-from typing import Dict
+from typing import Any, Dict

 import numpy as np
 import serializer.tosa_serializer as ts
@@ -316,3 +316,29 @@ def process_call_function(
         )
     else:
         raise RuntimeError(f"Unknown operator {node.target}")
+
+
+def expand_dims(
+    tosa_graph: ts.TosaSerializer, input_node: TosaArg, dtype: ts.DType, dim: int
+) -> Any:
+    """Inserts TOSA operators into tosa_graph that perform the equivalent of
+    the expand_dims (a.k.a. unsqueeze) operation. A new axis is created at
+    the dim location.
+
+    Args:
+        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
+        input_node (TosaArg): The parent node of the expand_dims operation.
+        dtype (ts.DType): The data type of the expand_dims operation.
+        dim (int): The dimension to expand.
+
+    Returns:
+        Any: The output tensor of the inserted operation in the TOSA graph.
+    """
+    new_shape = list(input_node.shape)
+    new_shape.insert(dim, 1)
+
+    intermediate = tosa_graph.addIntermediate(new_shape, dtype)
+
+    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)
+
+    return intermediate
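At the TOSA level this helper mirrors torch.unsqueeze; op_mm.py relies on its shape effect to lift rank-2 mm inputs to the rank 3 that TOSA MATMUL expects. A small sketch of the equivalent shape transformation in plain PyTorch:

    # The TOSA-level helper mirrors torch.unsqueeze: for dim=0 a rank-2
    # (H, W) tensor becomes rank-3 (1, H, W), as op_mm.py relies on.
    import torch

    x = torch.rand(3, 5)  # (H, W)
    y = x.unsqueeze(0)    # same shape effect as expand_dims(..., dim=0)
    assert y.shape == (1, 3, 5)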
