Add support for torch.pow in the Arm backend #9309

Merged (1 commit, Apr 1, 2025)
4 changes: 4 additions & 0 deletions backends/arm/_passes/__init__.py
@@ -41,6 +41,10 @@
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .remove_clone_pass import RemoveClonePass # noqa
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
)
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
13 changes: 6 additions & 7 deletions backends/arm/_passes/arm_pass_manager.py
@@ -42,18 +42,17 @@
MatchArgRanksPass,
QuantizeOperatorArguments,
RemoveClonePass,
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
RetraceFoldedDtypesPass,
ScalarsToAttributePass,
SizeAdjustConv2DPass,
UnsqueezeBeforeRepeatPass,
UnsqueezeScalarPlaceholdersPass,
)

from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform

from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.pass_manager import PassManager
@@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
self.add_pass(CastToInt32Pass())

self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
@@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
)

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
64 changes: 54 additions & 10 deletions backends/arm/_passes/insert_table_ops.py
@@ -5,7 +5,8 @@

# pyre-unsafe

from typing import Callable, Dict
from itertools import chain
from typing import Callable, cast, Dict, Iterator, Set

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
@@ -17,7 +18,7 @@

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule

from torch.fx.node import Node
from torch.library import impl, Library

lib = Library("tosa", "DEF")
@@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs): # pyre-ignore
return args[0].to(dtype=torch.int32)


class InsertTableOpsPass(ExportPass):
class TableOps:
"""
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
When lowering the _table node target_str will be used to find the corresponding torch operator
which will be used to produce the table values in operators/op_table.py.
Helper class for finding the corresponding table operator for a given Node.
"""

table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
# Targets that follow a straightforward one-to-one mapping to their table op
unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
exir_ops.edge.aten.ceil.default: torch.ceil,
exir_ops.edge.aten.exp.default: torch.exp,
exir_ops.edge.aten.floor.default: torch.floor,
@@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass):
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
}

# Targets that must be treated explicitly
special_table_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.pow.Tensor_Scalar,
}

def __init__(self, exported_program: ExportedProgram):
self.exported_program = exported_program

def __contains__(self, node: Node) -> bool:
return (
node.target in self.unary_table_ops or node.target in self.special_table_ops
)

def __getitem__(self, node: Node):
target = cast(EdgeOpOverload, node.target)
if target in self.unary_table_ops:
return self.unary_table_ops[target]
elif target in self.special_table_ops:
match target:
case exir_ops.edge.aten.pow.Tensor_Scalar:
# Exponent is a constant. Embed it into a lambda.
exp = cast(int, node.args[1])
return lambda x: torch.pow(x, exp).flatten()
case _:
# Op must be handled if it's inside self.special_table_ops
raise AssertionError("Unhandled table operation")
else:
raise KeyError("Table op for {target} does not exist")

@staticmethod
def included_ops() -> Iterator[EdgeOpOverload]:
return chain(TableOps.unary_table_ops, TableOps.special_table_ops)


class InsertTableOpsPass(ExportPass):
"""
Ops in self.table_ops need to be serialized as a TOSA TABLE. This pass replaces these
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
When lowering the _table node, target_str is used to find the corresponding torch operator,
which is then used to produce the table values in operators/op_table.py.
"""

def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.exported_program = exported_program
self.table_ops = TableOps(exported_program)

def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
"""
@@ -166,7 +208,7 @@ def generate_table_values(
def call(self, graph_module: GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in self.table_ops:
if node.op != "call_function" or node not in self.table_ops:
continue
input_qparams = node.meta["input_qparams"]
output_qparams = node.meta["output_qparams"]
@@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult:

# Generate table buffer and how much to lshift the table output.
buffer, lshift = self.generate_table_values(
torch_op=self.table_ops[node.target],
torch_op=self.table_ops[node],
in_quantargs=input_qparams[0],
out_quantargs=output_qparams[0],
)
@@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
output_node = rescale_node

node.replace_all_uses_with(output_node)

graph_module.graph.erase_node(node)

output_node.meta["input_qparams"] = input_qparams
output_node.meta["output_qparams"] = output_qparams
modified = True
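For context, the special-case handling above means a pow.Tensor_Scalar node bakes its constant exponent into the callable used to generate the lookup table. The following is a minimal, standalone sketch of that idea, not the pass's actual generate_table_values implementation; the build_table helper and its quantization parameters are illustrative assumptions.

# Illustrative sketch only -- not code from this PR. It shows how a constant
# exponent can be folded into the callable that generates an int8 lookup table.
import torch

def build_table(torch_op, scale_in, zp_in, scale_out, zp_out):
    # Evaluate the float op over every possible int8 input value.
    qmin, qmax = -128, 127
    q_in = torch.arange(qmin, qmax + 1, dtype=torch.int32)
    x = (q_in - zp_in) * scale_in                # dequantize
    y = torch_op(x)                              # run the float op
    q_out = torch.round(y / scale_out) + zp_out  # requantize
    return torch.clamp(q_out, qmin, qmax).to(torch.int8)

# For aten.pow.Tensor_Scalar the exponent is a constant node argument,
# so it is embedded in a lambda before the table values are produced:
exp = 2
table = build_table(lambda x: torch.pow(x, exp), 0.05, 0, 0.1, 0)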
1 change: 1 addition & 0 deletions backends/arm/_passes/match_arg_ranks_pass.py
@@ -48,6 +48,7 @@ def __init__(self, exported_program):
exir_ops.edge.aten.bitwise_right_shift.Tensor,
exir_ops.edge.aten.bitwise_left_shift.Tensor,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.pow.Tensor_Tensor,
]

def _match_op_rank(self, graph_module, node, arg, max_rank):
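Adding pow.Tensor_Tensor here matters because, after the scalar-to-tensor replacement, the exponent typically arrives as a rank-0 tensor while the base can have a higher rank, and this pass reshapes arguments so ranks agree. A small illustration of the mismatch being addressed (assumed behaviour, shown in eager mode):

# Eager PyTorch broadcasts a rank-0 exponent, but TOSA elementwise ops
# expect matching ranks, so the pass reshapes the exponent (e.g. to (1, 1, 1, 1)).
import torch

base = torch.rand(1, 3, 8, 8)      # rank 4
exponent = torch.tensor(2.0)       # rank 0, produced from the scalar argument
out = torch.pow(base, exponent)    # works in eager mode via broadcasting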
53 changes: 53 additions & 0 deletions backends/arm/_passes/replace_scalar_with_tensor_pass.py
@@ -0,0 +1,53 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import Dict

import torch
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload


# Operators that are included for both TOSA profiles
_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
}


class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass):
scalar_to_tensor_ops = _common_ops | {
exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
}

def __init__(self):
super().__init__(self.scalar_to_tensor_ops)


class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass):
scalar_to_tensor_ops = _common_ops

def __init__(self):
super().__init__(self.scalar_to_tensor_ops)
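Only the MI (floating-point) variant maps pow.Tensor_Scalar to pow.Tensor_Tensor; the BI (integer) variant keeps the scalar-exponent form so it can be lowered through the table mechanism instead. A minimal sketch of checking that difference, assuming executorch with this change is installed and the classes are exported from backends/arm/_passes as shown in this diff:

# Sketch: the two pass variants differ only in the pow.Tensor_Scalar mapping.
from executorch.backends.arm._passes import (
    ReplaceScalarWithTensorArgPassTOSABI,
    ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops

mi_ops = ReplaceScalarWithTensorArgPassTOSAMI.scalar_to_tensor_ops
bi_ops = ReplaceScalarWithTensorArgPassTOSABI.scalar_to_tensor_ops

# pow.Tensor_Scalar is only rewritten for the MI profile; the BI profile
# leaves it for InsertTableOpsPass to turn into a tosa._table lookup.
assert exir_ops.edge.aten.pow.Tensor_Scalar in mi_ops
assert exir_ops.edge.aten.pow.Tensor_Scalar not in bi_ops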
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -198,6 +198,8 @@ def is_node_supported(
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.pow.Tensor_Tensor,
operator.getitem,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -32,6 +32,7 @@
op_minimum,
op_mul,
op_permute,
op_pow,
op_reciprocal,
op_repeat,
op_rescale,
57 changes: 57 additions & 0 deletions backends/arm/operators/op_pow.py
@@ -0,0 +1,57 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class PowVisitor_080_MI(NodeVisitor):
target = "aten.pow.Tensor_Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
)
if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
raise ValueError(
f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
)

tosa_graph.addOperator(
TosaOp.Op().POW,
[
inputs[0].name,
inputs[1].name,
],
[output.name],
None,
)
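End-to-end, a module calling torch.pow with a floating-point tensor exponent should now lower through the TOSA-0.80+MI pipeline to a TOSA POW operator. A hedged example of an eager model that would exercise this path (only the model is shown; the export and lowering pipeline is outside this diff):

# Example module expected to hit the new aten.pow.Tensor_Tensor visitor on the
# TOSA-0.80+MI profile (FP16/FP32 only, per the dtype checks above).
import torch

class Pow(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.pow(x, y)

x = torch.rand(1, 3, 16, 16)
y = torch.full((1, 3, 16, 16), 2.0)
print(Pow()(x, y).shape)  # torch.Size([1, 3, 16, 16])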
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -139,6 +139,7 @@ def _match_pattern(
torch.ops.aten.hardswish.default,
torch.ops.aten.hardswish_.default,
torch.ops.aten.full_like.default,
torch.ops.aten.pow.Tensor_Scalar,
]

_one_to_one_shared_input_qspec = [