Add reduce_sum op to ArmBackend #6044

Closed · wants to merge 1 commit
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -19,6 +19,9 @@
ConvertSplitToSlicePass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
InsertSqueezeAfterSumPass,
)
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)
@@ -47,6 +50,7 @@ def transform_to_backend_pipeline(
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeDivPass())
self.add_pass(InsertSqueezeAfterSumPass())
self.add_pass(ConvertSplitToSlicePass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
69 changes: 69 additions & 0 deletions backends/arm/_passes/insert_squeeze_after_sum_pass.py
@@ -0,0 +1,69 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertSqueezeAfterSumPass(ExportPass):
"""
In PyTorch, the default behaviour of Tensor.sum is to squeeze
the dimension that is summed (keep_dim = False).
However, in TOSA, REDUCE_SUM always preserves the
rank of the input (keep_dim = True).
To get a 1-1 mapping in the sum lowering, normalize the
keep_dim = False case to keep_dim = True and add squeeze ops.

Original:
sum(dims, keep_dim = False)
After pass:
sum(dims, keep_dim = True)
(q)
(dq)
squeeze(dim = dims)
"""

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target != exir_ops.edge.aten.sum.dim_IntList:
continue
sum_node = cast(torch.fx.Node, node)
keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
if keep_dim:
continue

dim_list = cast(list[int], sum_node.args[1])
quantized = is_quant_node(sum_node)
if quantized:
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
qparams = qparams + (torch.int8,)
else:
qparams = None

# Add keep_dim = True arg to sum node.
sum_node.args = sum_node.args[0:2] + (True,)

with graph_module.graph.inserting_after(sum_node):
squeeze_node = create_node(
graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
)
sum_node.replace_all_uses_with(squeeze_node)
squeeze_node.args = (sum_node, dim_list)
if quantized:
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
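
For reference, a minimal sketch (not part of the diff) of the rewrite this pass performs, expressed with plain PyTorch ops; the tensor shape and dim values are illustrative only:

import torch

x = torch.randn(2, 3, 4)

# Original aten graph: keep_dim = False drops the reduced dimension.
y_original = torch.sum(x, dim=[1], keepdim=False)  # shape: (2, 4)

# After the pass: the sum keeps the reduced dimension (rank preserved, as
# TOSA REDUCE_SUM requires) and a squeeze over the same dims is appended.
# In the quantized case a q/dq pair is inserted between the two ops.
y_kept = torch.sum(x, dim=[1], keepdim=True)  # shape: (2, 1, 4)
y_after_pass = y_kept.squeeze(1)              # shape: (2, 4)

assert torch.allclose(y_original, y_after_pass)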
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.mean.dim,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -34,6 +34,7 @@
op_softmax,
op_squeeze,
op_sub,
op_sum,
op_unsqueeze,
op_view,
)
96 changes: 96 additions & 0 deletions backends/arm/operators/op_sum.py
@@ -0,0 +1,96 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class SumVisitor(NodeVisitor):
target = "aten.sum.dim_IntList"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
input_node = inputs[0]
input_shape = list(input_node.shape)
dim_list = cast(list[int], inputs[1].special)
dim_list = [dim % len(input_node.shape) for dim in dim_list]
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"

if is_quant_node:

# Rescale input to 32 bit
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
[node.all_input_nodes[0]], tosa_graph
)

prev_node = rescaled_inputs[0]
reduced_shape = input_shape

# Reduce all dims in dim_list one-by-one.
for dim in dim_list:
# When reduced, the size of the dim becomes 1.
reduced_shape[dim] = 1

attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(input_node.dim_order.index(dim))

next_node = tosa_graph.addIntermediate(
tutils.tosa_shape(reduced_shape, input_node.dim_order),
dtype=ts.DType.INT32,
)

tosa_graph.addOperator(
TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
)

prev_node = next_node
tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
else:
input_name = input_node.name
reduced_shape = input_shape

# Reduce all dims in dim_list one-by-one.
for dim in dim_list:
# When reduced, the size of the dim becomes 1
reduced_shape[dim] = 1

attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(input_node.dim_order.index(dim))

if dim == dim_list[-1]:
output_name = output.name
else:
output_name = tosa_graph.addIntermediate(
tutils.tosa_shape(reduced_shape, input_node.dim_order),
dtype=ts.DType.FP32,
).name

tosa_graph.addOperator(
TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
)

input_name = output_name
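
A quick numerical check (not part of the diff) of the one-axis-at-a-time decomposition used above: chaining single-axis sums while preserving the rank is equivalent to a single multi-axis reduction, which is what lets each step map to one TOSA REDUCE_SUM:

import torch

x = torch.randn(2, 3, 4, 5)
dim_list = [1, 3]

expected = torch.sum(x, dim=dim_list, keepdim=True)  # shape: (2, 1, 4, 1)

out = x
for dim in dim_list:
    # Each step reduces exactly one axis and keeps the rank, mirroring the
    # chain of REDUCE_SUM operators emitted above.
    out = torch.sum(out, dim=dim, keepdim=True)

assert torch.allclose(out, expected)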
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -272,6 +272,7 @@ class ArmQuantizer(Quantizer):
"cat",
"one_to_one",
"generic",
"sum",
]

def __init__(self) -> None:
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotation/__init__.py
@@ -61,4 +61,5 @@ def decorator(annotator: AnnotatorType):
one_to_one_annotator,
sigmoid_annotator,
sub_annotator,
sum_annotator,
)
57 changes: 57 additions & 0 deletions backends/arm/quantizer/quantization_annotation/sum_annotator.py
@@ -0,0 +1,57 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, cast, List, Optional

import torch
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig

from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
QuantizationSpecBase,
SharedQuantizationSpec,
)
from torch.fx import Node


@register_annotator("sum")
def _annotate_sum(
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
annotated_partitions = []
for node in gm.graph.nodes:
if node.target is not torch.ops.aten.sum.dim_IntList:
continue
if filter_fn and not filter_fn(node):
continue

sum_node = node
if arm_quantizer_utils.is_annotated(sum_node):
continue

input_act = sum_node.args[0]

if not isinstance(input_act, Node):
continue
if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm):
continue

input_act_qspec = cast(
Optional[QuantizationSpecBase], quantization_config.get_input_act_qspec()
)
input_qspec_map = {input_act: input_act_qspec}
shared_with_input0_qspec = SharedQuantizationSpec((input_act, sum_node))

sum_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=shared_with_input0_qspec,
_annotated=True,
)
annotated_partitions.append([sum_node])
return annotated_partitions
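
A minimal sketch (not part of the diff) of how the new "sum" annotator would typically be exercised through the PT2E quantization flow. The export entry point and the symmetric config helper are assumptions based on the surrounding ArmQuantizer code, not something this PR adds:

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1, keepdim=False)


example_inputs = (torch.randn(2, 3, 4),)

# Capture a pre-autograd graph; the exact export entry point depends on the
# torch version in use.
graph_module = torch.export.export_for_training(
    SumModule(), example_inputs
).module()

quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(graph_module, quantizer)  # sum input gets annotated;
                                                  # its output shares the qspec
prepared(*example_inputs)                         # calibration
quantized = convert_pt2e(prepared)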