Skip to content

Commit 5379416

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Add POW operator
Implement support for torch.pow in the MI and BI profiles of TOSA. For MI, the operator works like PyTorch's reference implementation, except that the base operand cannot be a scalar and must be a tensor. For BI, the exponent operand must be a constant scalar, and the base operand must be a tensor. Split the ReplaceScalarWithTensorArgPass into two subclasses: one for MI and one for BI. For MI, the pow operator's exponent will be converted to a tensor if it is a scalar. For BI, the scalar will be kept and instead consumed in the InsertTableOpsPass, meaning that the operator will be converted into a table operation with one input and one output. This still enforces that the exponent is constant for the BI profile. Change-Id: I464ab91ff46c0a6ad28d0fb84735a403a74e6323
1 parent feb3fcd commit 5379416

File tree

11 files changed

+335
-25
lines changed

11 files changed

+335
-25
lines changed

backends/arm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
4242
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
4343
from .remove_clone_pass import RemoveClonePass # noqa
44+
from .replace_scalar_with_tensor_pass import ( # noqa
45+
ReplaceScalarWithTensorArgPassTOSABI,
46+
ReplaceScalarWithTensorArgPassTOSAMI,
47+
)
4448
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
4549
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
4650
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,17 @@
4242
MatchArgRanksPass,
4343
QuantizeOperatorArguments,
4444
RemoveClonePass,
45+
ReplaceScalarWithTensorArgPassTOSABI,
46+
ReplaceScalarWithTensorArgPassTOSAMI,
4547
RetraceFoldedDtypesPass,
4648
ScalarsToAttributePass,
4749
SizeAdjustConv2DPass,
4850
UnsqueezeBeforeRepeatPass,
4951
UnsqueezeScalarPlaceholdersPass,
5052
)
53+
5154
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
5255
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
53-
54-
from executorch.backends.transforms.replace_scalar_with_tensor import (
55-
ReplaceScalarWithTensorArgPass,
56-
)
5756
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
5857
from executorch.exir import ExportedProgram
5958
from executorch.exir.pass_manager import PassManager
@@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8483
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
8584
self.add_pass(CastToInt32Pass())
8685

87-
self.add_pass(ReplaceScalarWithTensorArgPass())
86+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
8887
self.add_pass(AnnotateDecomposedMatmulPass())
8988
self.add_pass(QuantizeOperatorArguments())
9089
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
113112
return self._transform(exported_program.graph_module)
114113

115114
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
116-
self.add_pass(ReplaceScalarWithTensorArgPass())
115+
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
117116
self.add_pass(FuseQuantizedActivationPass())
118117
self.add_pass(RemoveGetItemPass())
119118
self.add_pass(ConvertSplitToSlicePass())
@@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
170169
)
171170

172171
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
173-
self.add_pass(ReplaceScalarWithTensorArgPass())
172+
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174173
self.add_pass(ScalarsToAttributePass())
175174
self.add_pass(DecomposeLayerNormPass())
176175
self.add_pass(DecomposeVarPass())

backends/arm/_passes/insert_table_ops.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
# pyre-unsafe
77

8-
from typing import Callable, Dict
8+
from itertools import chain
9+
from typing import Callable, cast, Dict, Iterator, Set
910

1011
import torch
1112
from executorch.backends.arm._passes.arm_pass_utils import create_node
@@ -17,7 +18,7 @@
1718

1819
from executorch.exir.pass_base import ExportPass, PassResult
1920
from torch.fx import GraphModule
20-
21+
from torch.fx.node import Node
2122
from torch.library import impl, Library
2223

2324
lib = Library("tosa", "DEF")
@@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs): # pyre-ignore
3233
return args[0].to(dtype=torch.int32)
3334

3435

35-
class InsertTableOpsPass(ExportPass):
36+
class TableOps:
3637
"""
37-
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
38-
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
39-
When lowering the _table node target_str will be used to find the corresponding torch operator
40-
which will be used to produce the table values in operators/op_table.py.
38+
Helper class for finding the corresponding table operator for a given Node.
4139
"""
4240

43-
table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
41+
# Targets that follow a straightforward one-to-one mapping to their table op
42+
unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
4443
exir_ops.edge.aten.ceil.default: torch.ceil,
4544
exir_ops.edge.aten.exp.default: torch.exp,
4645
exir_ops.edge.aten.floor.default: torch.floor,
@@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass):
5352
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
5453
}
5554

55+
# Targets that must be treated explicitly
56+
special_table_ops: Set[EdgeOpOverload] = {
57+
exir_ops.edge.aten.pow.Tensor_Scalar,
58+
}
59+
60+
def __init__(self, exported_program: ExportedProgram):
61+
self.exported_program = exported_program
62+
63+
def __contains__(self, node: Node) -> bool:
    """Return True if ``node.target`` is a known table-op target.

    A target is known when it is a key of ``unary_table_ops`` or a member
    of ``special_table_ops``.
    """
    op = node.target
    if op in self.unary_table_ops:
        return True
    return op in self.special_table_ops
67+
68+
def __getitem__(self, node: Node):
69+
target = cast(EdgeOpOverload, node.target)
70+
if target in self.unary_table_ops:
71+
return self.unary_table_ops[target]
72+
elif target in self.special_table_ops:
73+
match target:
74+
case exir_ops.edge.aten.pow.Tensor_Scalar:
75+
# Exponent is a constant. Embed it into a lambda.
76+
exp = cast(int, node.args[1])
77+
return lambda x: torch.pow(x, exp).flatten()
78+
case _:
79+
# Op must be handled if it's inside self.special_ops
80+
raise AssertionError("Unhandled table operation")
81+
else:
82+
raise KeyError("Table op for {target} does not exist")
83+
84+
@staticmethod
def included_ops() -> Iterator[EdgeOpOverload]:
    """Iterate over every target covered by TableOps.

    Yields the keys of ``unary_table_ops`` first, then the members of
    ``special_table_ops``.
    """
    return chain.from_iterable(
        (TableOps.unary_table_ops, TableOps.special_table_ops)
    )
87+
88+
89+
class InsertTableOpsPass(ExportPass):
90+
"""
91+
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
92+
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
93+
When lowering the _table node target_str will be used to find the corresponding torch operator
94+
which will be used to produce the table values in operators/op_table.py.
95+
"""
96+
5697
def __init__(self, exported_program: ExportedProgram) -> None:
5798
super().__init__()
5899
self.exported_program = exported_program
100+
self.table_ops = TableOps(exported_program)
59101

60102
def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
61103
"""
@@ -166,7 +208,7 @@ def generate_table_values(
166208
def call(self, graph_module: GraphModule) -> PassResult:
167209
modified = False
168210
for node in graph_module.graph.nodes:
169-
if node.op != "call_function" or node.target not in self.table_ops:
211+
if node.op != "call_function" or node not in self.table_ops:
170212
continue
171213
input_qparams = node.meta["input_qparams"]
172214
output_qparams = node.meta["output_qparams"]
@@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
186228

187229
# Generate table buffer and how much to lshift the table output.
188230
buffer, lshift = self.generate_table_values(
189-
torch_op=self.table_ops[node.target],
231+
torch_op=self.table_ops[node],
190232
in_quantargs=input_qparams[0],
191233
out_quantargs=output_qparams[0],
192234
)
@@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
207249
output_node = rescale_node
208250

209251
node.replace_all_uses_with(output_node)
252+
210253
graph_module.graph.erase_node(node)
254+
211255
output_node.meta["input_qparams"] = input_qparams
212256
output_node.meta["output_qparams"] = output_qparams
213257
modified = True

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(self, exported_program):
4848
exir_ops.edge.aten.bitwise_right_shift.Tensor,
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
51+
exir_ops.edge.aten.pow.Tensor_Tensor,
5152
]
5253

5354
def _match_op_rank(self, graph_module, node, arg, max_rank):
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
from typing import Dict
10+
11+
import torch
12+
from executorch.backends.transforms.replace_scalar_with_tensor import (
13+
ReplaceScalarWithTensorArgPass,
14+
)
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
18+
19+
20+
# Operators that are included for both TOSA profiles
21+
_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
    # Edge-dialect scalar overloads -> their tensor counterparts.
    exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
    exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
    exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
} | {
    # ATen scalar overloads -> their tensor counterparts.
    torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
    torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
    torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
    torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
    torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
    torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
}
37+
38+
39+
class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass):
    """Scalar-to-tensor replacement pass configured for the TOSA MI profile.

    Uses the common scalar-to-tensor mapping and additionally maps
    ``pow.Tensor_Scalar`` to ``pow.Tensor_Tensor`` for both the edge and
    ATen overloads.
    """

    scalar_to_tensor_ops = {
        **_common_ops,
        exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
        torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
    }

    def __init__(self):
        # Hand the MI-specific mapping to the generic base pass.
        super().__init__(self.scalar_to_tensor_ops)
47+
48+
49+
class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass):
    """Scalar-to-tensor replacement pass configured for the TOSA BI profile.

    Only the common mapping is used; ``pow.Tensor_Scalar`` is deliberately
    absent, so this pass leaves that operator's scalar exponent in place.
    """

    # Same dict object as the module-level mapping (no copy is made).
    scalar_to_tensor_ops = _common_ops

    def __init__(self):
        # Hand the BI mapping to the generic base pass.
        super().__init__(self.scalar_to_tensor_ops)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def is_node_supported(
198198
exir_ops.edge.aten.clone.default,
199199
exir_ops.edge.aten.unsqueeze_copy.default,
200200
exir_ops.edge.aten.squeeze_copy.dims,
201+
exir_ops.edge.aten.pow.Tensor_Scalar,
202+
exir_ops.edge.aten.pow.Tensor_Tensor,
201203
operator.getitem,
202204
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
203205
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
op_minimum,
3333
op_mul,
3434
op_permute,
35+
op_pow,
3536
op_reciprocal,
3637
op_repeat,
3738
op_rescale,

backends/arm/operators/op_pow.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import serializer.tosa_serializer as ts
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
17+
from serializer.tosa_serializer import TosaOp
18+
from torch.fx import Node
19+
20+
21+
@register_node_visitor
class PowVisitor_080_MI(NodeVisitor):
    """Node visitor that serializes ``aten.pow.Tensor_Tensor`` as a TOSA POW.

    Registered only for the TOSA-0.80+MI specification.
    """

    target = "aten.pow.Tensor_Tensor"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add a POW operator for ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this method).
            tosa_graph: Serializer that receives the POW operator.
            inputs: Operands; ``inputs[0]`` and ``inputs[1]`` are passed to
                POW in that order.
            output: Destination of the POW result.

        Raises:
            ValueError: if the two inputs and the output do not share one
                dtype, or if that dtype is neither FP32 nor FP16.
        """
        if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
            # Trailing space keeps the two adjacent literals from running
            # together as "dtype.Got" in the rendered message.
            raise ValueError(
                "All inputs and outputs need same dtype. "
                f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
            )
        if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
            raise ValueError(
                f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
            )

        tosa_graph.addOperator(
            TosaOp.Op().POW,
            [
                inputs[0].name,
                inputs[1].name,
            ],
            [output.name],
            None,  # presumably the operator attributes (none for POW) - TODO confirm
        )

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _match_pattern(
139139
torch.ops.aten.hardswish.default,
140140
torch.ops.aten.hardswish_.default,
141141
torch.ops.aten.full_like.default,
142+
torch.ops.aten.pow.Tensor_Scalar,
142143
]
143144

144145
_one_to_one_shared_input_qspec = [

0 commit comments

Comments
 (0)