Implement mm op for Arm backend #4628

Merged 2 commits on Aug 15, 2024
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -47,6 +47,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.repeat.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -16,6 +16,7 @@
    op_get_item,
    op_hardtanh,
    op_mean_dim,
    op_mm,
    op_permute,
    op_quant,
    op_repeat,
106 changes: 106 additions & 0 deletions backends/arm/operators/op_mm.py
@@ -0,0 +1,106 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_utils import (
    build_reshape,
    expand_dims,
    get_two_inputs,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class MMVisitor(NodeVisitor):
    target = "aten.mm.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        input0, input1 = get_two_inputs(node)

        # For aten.mm, the two inputs are of rank 2.
        # For TOSA they need to be rank 3,
        # so they are reshaped from (H, W) to (1, H, W).
        # NOTE: For now, only INT8 & FP32 are supported.
        reshape_dtype = ts.DType.INT8 if is_quant_node else ts.DType.FP32
        input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0)
        input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0)

        # The output also needs to be rank 3.
        output_new_shape = (1, output.shape[0], output.shape[1])

        # For INT8, we need the zero points; otherwise they are 0.
        input0_zp, input1_zp = 0, 0
        if is_quant_node:
            input0_zp = get_quant_node_args(input0).zp
            input1_zp = get_quant_node_args(input1).zp

        mat_mul_result = tosa_graph.addIntermediate(
            output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
        )

        attr = ts.TosaSerializerAttribute()
        attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp)

        tosa_graph.addOperator(
            TosaOp.Op().MATMUL,
            [input0_reshaped.name, input1_reshaped.name],
            [mat_mul_result.name],
            attr,
        )

        if is_quant_node:
            reshape_intermediate = tosa_graph.addIntermediate(
                output.shape, ts.DType.INT32
            )
            reshape_output_name = reshape_intermediate.name
        else:
            reshape_output_name = output.name

        # Reshape the final output back to rank 2.
        build_reshape(
            tosa_graph, mat_mul_result.name, output.shape, reshape_output_name
        )

        # As INT8 accumulates into INT32, we need to rescale it back to INT8.
        if is_quant_node:
            input0_q_params = get_quant_node_args(input0)
            input1_q_params = get_quant_node_args(input1)
            output_q_params = get_quant_node_args(list(node.users)[0])

            final_output_scale = (
                input0_q_params.scale * input1_q_params.scale
            ) / output_q_params.scale

            # As the input will be INT32, the input_zp must be set to 0.
            build_rescale(
                tosa_fb=tosa_graph,
                scale=final_output_scale,
                input_node=reshape_intermediate,
                output_name=output.name,
                output_type=ts.DType.INT8,
                output_shape=reshape_intermediate.shape,
                input_zp=0,
                output_zp=output_q_params.zp,
                is_double_round=False,
            )
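
For intuition about the rescale above: the INT32 accumulator holds sums of products of two INT8 values, so each accumulator unit represents input0_scale * input1_scale in real terms, and dividing by output_scale converts that into output units. A minimal sketch of the arithmetic, with made-up quantization scales (illustrative values, not taken from this PR):

# Hypothetical quantization scales, for illustration only.
input0_scale = 0.02  # real value ~ scale * (q - zp) for the first INT8 input
input1_scale = 0.05  # scale of the second INT8 input
output_scale = 0.1   # scale of the INT8 output

# Each INT32 accumulator unit is worth input0_scale * input1_scale;
# dividing by output_scale re-expresses it in output INT8 units.
final_output_scale = (input0_scale * input1_scale) / output_scale
print(final_output_scale)  # ~0.01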
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -265,6 +265,7 @@ class ArmQuantizer(Quantizer):
"sub",
"mul",
"sigmoid",
"mm",
]

def __init__(self) -> None:
@@ -52,6 +52,7 @@ def decorator(annotator: AnnotatorType):
    conv_annotator,
    linear_annotator,
    max_pool2d_annotator,
    mm_annotator,
    mul_annotator,
    sigmoid_annotator,
    sub_annotator,
56 changes: 56 additions & 0 deletions backends/arm/quantizer/quantization_annotation/mm_annotator.py
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("mm")
def _annotate_mm(
    gm: torch.fx.GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    mm_partitions = get_source_partitions(gm.graph, [torch.mm], filter_fn)
    mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values()))
    annotated_partitions = []
    for mm_partition in mm_partitions:
        annotated_partitions.append(mm_partition.nodes)
        mm_node = mm_partition.output_nodes[0]

        if arm_quantizer_utils.is_annotated(mm_node):
            continue

        input_act_qspec = quantization_config.get_input_act_qspec()
        output_act_qspec = quantization_config.get_output_act_qspec()

        input_qspec_map = {}
        input_act0 = mm_node.args[0]
        if isinstance(input_act0, Node):
            if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mm_node.args[1]
        if isinstance(input_act1, Node):
            if not arm_quantizer_utils.is_input_ok_for_quantization(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mm_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions
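
For context, get_source_partitions groups graph nodes by the source callable recorded in node metadata during tracing, which is what lets the annotator find every torch.mm call. A small sketch of that matching, using torch.export purely for illustration (the Arm quantizer runs on its own captured graph; the Toy module is hypothetical):

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.mm(x, y)


# Export-style tracing records which source callable produced each node;
# get_source_partitions matches that metadata against [torch.mm].
ep = torch.export.export(Toy(), (torch.rand(3, 5), torch.rand(5, 2)))
partitions = get_source_partitions(ep.graph_module.graph, [torch.mm])
# Expect one partition whose output node is the aten.mm call.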
137 changes: 137 additions & 0 deletions backends/arm/test/ops/test_mm.py
@@ -0,0 +1,137 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TestMM(unittest.TestCase):
    """Tests MatMul"""

    class MM(torch.nn.Module):
        test_parameters = [
            (torch.rand(3, 5), torch.rand(5, 2)),
            (torch.rand(1, 1), torch.rand(1, 1)),
            (torch.ones(55, 3), torch.ones(3, 44)),
            (10000 * torch.randn(1, 10), torch.randn(10, 5)),
            (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
        ]

        def forward(self, x, y):
            return torch.mm(x, y)

    class MMSingleInput(torch.nn.Module):
        test_parameters = [
            (torch.rand(3, 3),),
            (torch.ones(128, 128),),
            (10000 * torch.randn(25, 25),),
            (5 + 5 * torch.randn(64, 64),),
        ]

        def forward(self, x):
            return torch.mm(x, x)

    def _test_mm_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mm_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mm_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(MM.test_parameters)
    def test_mm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_tosa_MI_pipeline(self.MM(), test_data)

    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_tosa_MI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)

    @parameterized.expand(MM.test_parameters)
    def test_mm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_tosa_BI_pipeline(self.MM(), test_data)

    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_tosa_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)
    @parameterized.expand(MM.test_parameters)
    def test_mm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_u55_BI_pipeline(self.MM(), test_data)

    # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale
    # previous operator is of type: Memcpy
    @parameterized.expand(MMSingleInput.test_parameters)
    @unittest.expectedFailure
    def test_mm_single_input_u55_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_u55_BI_pipeline(self.MMSingleInput(), test_data)
28 changes: 27 additions & 1 deletion backends/arm/tosa_utils.py
@@ -5,7 +5,7 @@

import logging
import os
from typing import Dict
from typing import Any, Dict

import numpy as np
import serializer.tosa_serializer as ts
@@ -316,3 +316,29 @@ def process_call_function(
        )
    else:
        raise RuntimeError(f"Unknown operator {node.target}")


def expand_dims(
    tosa_graph: ts.TosaSerializer, input_node: TosaArg, dtype: ts.DType, dim: int
) -> Any:
    """Inserts TOSA operators into the tosa_graph that perform the equivalent
    of the expand_dims (a.k.a. unsqueeze) operation. A new axis is created at
    the dim location.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The parent node of the expand dims operation.
        dtype (ts.DType): The data type of the expand dims operation.
        dim (int): The dimension to expand.

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.
    """
    new_shape = list(input_node.shape)
    new_shape.insert(dim, 1)

    intermediate = tosa_graph.addIntermediate(new_shape, dtype)

    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

    return intermediate
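
For reference, the shape bookkeeping here mirrors torch.unsqueeze; a quick illustration of the rank-2 to rank-3 case that op_mm.py relies on:

import torch

x = torch.rand(3, 5)           # a rank-2 aten.mm operand
x3 = x.unsqueeze(0)            # rank 3, as TOSA MATMUL expects
assert x3.shape == (1, 3, 5)

# The same insert-a-1 computation expand_dims performs on the TOSA side:
new_shape = list(x.shape)
new_shape.insert(0, 1)
assert new_shape == [1, 3, 5]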