
Commit 900ca66

Martin Lindström authored and committed
Arm backend: Add POW operator
Implement support for torch.pow in the MI and BI profiles of TOSA.

For MI, the operator behaves like PyTorch's reference implementation, except that the base operand cannot be a scalar and must be a tensor. For BI, the exponent operand must be a constant scalar, and the base operand must be a tensor.

Split ReplaceScalarWithTensorArgPass into two subclasses: one for MI and one for BI. For MI, the pow operator's exponent is converted to a tensor in case it is a scalar. For BI, the scalar is kept and instead consumed by InsertTableOpsPass, which converts the operator into a table operation with one input and one output. This still enforces the exponent to be constant for the BI profile.

Change-Id: I464ab91ff46c0a6ad28d0fb84735a403a74e6323
1 parent 265b9b7 commit 900ca66
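A minimal sketch of the two supported pow patterns described in the commit message (illustrative only; actually lowering these modules additionally requires the Arm partitioner and an appropriate TOSA compile spec):

import torch

class PowMI(torch.nn.Module):
    # MI profile: base and exponent are both tensors (aten.pow.Tensor_Tensor).
    def forward(self, base: torch.Tensor, exp: torch.Tensor) -> torch.Tensor:
        return torch.pow(base, exp)

class PowBI(torch.nn.Module):
    # BI profile: the exponent is a constant scalar (aten.pow.Tensor_Scalar),
    # which InsertTableOpsPass later turns into a TOSA TABLE lookup.
    def forward(self, base: torch.Tensor) -> torch.Tensor:
        return torch.pow(base, 2)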

10 files changed, +331 -19 lines changed


backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 7 deletions
@@ -75,6 +75,11 @@
     ConvertMmToBmmPass,
 )
 from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
+
+from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
+    ReplaceScalarWithTensorArgPassTOSABI,
+    ReplaceScalarWithTensorArgPassTOSAMI,
+)
 from executorch.backends.arm._passes.scalars_to_attribute_pass import (
     ScalarsToAttributePass,
 )
@@ -87,10 +92,6 @@
 )
 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-
-from executorch.backends.transforms.replace_scalar_with_tensor import (
-    ReplaceScalarWithTensorArgPass,
-)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -119,7 +120,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
 
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
@@ -148,7 +149,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -205,7 +206,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())

backends/arm/_passes/insert_table_ops.py

Lines changed: 54 additions & 10 deletions
@@ -5,7 +5,8 @@
 
 # pyre-unsafe
 
-from typing import Callable, Dict
+from itertools import chain
+from typing import Callable, cast, Dict, Iterator, Set
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
@@ -17,7 +18,7 @@
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
-
+from torch.fx.node import Node
 from torch.library import impl, Library
 
 lib = Library("tosa", "DEF")
@@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs):  # pyre-ignore
     return args[0].to(dtype=torch.int32)
 
 
-class InsertTableOpsPass(ExportPass):
+class TableOps:
     """
-    For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
-    edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
-    When lowering the _table node target_str will be used to find the corresponding torch operator
-    which will be used to produce the table values in operators/op_table.py.
+    Helper class for finding the corresponding table operator for a given Node.
     """
 
-    table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
+    # Targets that follow a straightforward one-to-one mapping to their table op
+    unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
         exir_ops.edge.aten.ceil.default: torch.ceil,
         exir_ops.edge.aten.exp.default: torch.exp,
         exir_ops.edge.aten.floor.default: torch.floor,
@@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass):
         exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
     }
 
+    # Targets that must be treated explicitly
+    special_table_ops: Set[EdgeOpOverload] = {
+        exir_ops.edge.aten.pow.Tensor_Scalar,
+    }
+
+    def __init__(self, exported_program: ExportedProgram):
+        self.exported_program = exported_program
+
+    def __contains__(self, node: Node) -> bool:
+        return (
+            node.target in self.unary_table_ops or node.target in self.special_table_ops
+        )
+
+    def __getitem__(self, node: Node):
+        target = cast(EdgeOpOverload, node.target)
+        if target in self.unary_table_ops:
+            return self.unary_table_ops[target]
+        elif target in self.special_table_ops:
+            match target:
+                case exir_ops.edge.aten.pow.Tensor_Scalar:
+                    # Exponent is a constant. Embed it into a lambda.
+                    exp = cast(int, node.args[1])
+                    return lambda x: torch.pow(x, exp).flatten()
+                case _:
+                    # Op must be handled if it's inside self.special_table_ops
+                    raise AssertionError("Unhandled table operation")
+        else:
+            raise KeyError(f"Table op for {target} does not exist")
+
+    @staticmethod
+    def included_ops() -> Iterator[EdgeOpOverload]:
+        return chain(TableOps.unary_table_ops, TableOps.special_table_ops)
+
+
+class InsertTableOpsPass(ExportPass):
+    """
+    Ops in self.table_ops need to be serialized as a TOSA TABLE. This pass replaces these
+    edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
+    When lowering the _table node, target_str is used to find the corresponding torch operator,
+    which is used to produce the table values in operators/op_table.py.
+    """
+
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.exported_program = exported_program
+        self.table_ops = TableOps(exported_program)
 
     def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
         """
@@ -166,7 +208,7 @@ def generate_table_values(
     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
         for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in self.table_ops:
+            if node.op != "call_function" or node not in self.table_ops:
                 continue
             input_qparams = node.meta["input_qparams"]
             output_qparams = node.meta["output_qparams"]
@@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             # Generate table buffer and how much to lshift the table output.
             buffer, lshift = self.generate_table_values(
-                torch_op=self.table_ops[node.target],
+                torch_op=self.table_ops[node],
                 in_quantargs=input_qparams[0],
                 out_quantargs=output_qparams[0],
             )
@@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 output_node = rescale_node
 
             node.replace_all_uses_with(output_node)
+
             graph_module.graph.erase_node(node)
+
             output_node.meta["input_qparams"] = input_qparams
             output_node.meta["output_qparams"] = output_qparams
             modified = True
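The special case above folds the constant exponent into a closure so that table generation only ever deals with a unary Tensor -> Tensor callable. A rough sketch of the idea (the scale and zero-point values are hypothetical; the real generate_table_values additionally computes an output rescale/lshift):

import torch

def pow_table_fn(exp: int):
    # Mirrors TableOps.__getitem__ for pow.Tensor_Scalar: capture the constant
    # exponent so only a unary function of the input tensor remains.
    return lambda x: torch.pow(x, exp).flatten()

# Evaluate the function over every dequantized int8 input value to obtain
# the 256 table entries.
scale, zero_point = 0.05, 0  # hypothetical input quantization parameters
qrange = torch.arange(-128, 128, dtype=torch.float32)
table_values = pow_table_fn(2)((qrange - zero_point) * scale)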

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def __init__(self, exported_program):
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten.pow.Tensor_Tensor,
         ]
 
     def _match_op_rank(self, graph_module, node, arg, max_rank):
backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+from typing import Dict
+
+import torch
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+
+
+class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass):
+    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
+        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
+        exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
+        torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
+        torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
+        torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
+        torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
+        torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
+    }
+
+    def __init__(self):
+        super().__init__(self.scalar_to_tensor_ops)
+
+
+class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass):
+    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
+        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
+        torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
+        torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
+        torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
+        torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
+    }
+
+    def __init__(self):
+        super().__init__(self.scalar_to_tensor_ops)
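As a usage illustration (a sketch, not backend code): on MI the pass rewrites a pow with a scalar exponent into the tensor-exponent overload, which the PowVisitor below can then serialize; the BI subclass omits the pow entries, so aten.pow.Tensor_Scalar survives for InsertTableOpsPass to consume.

import torch

x = torch.randn(2, 3)

# Before the MI pass: the scalar exponent is baked into the op.
before = torch.ops.aten.pow.Tensor_Scalar(x, 2)

# After the MI pass (conceptually): the scalar is lifted into a tensor
# operand and the op becomes aten.pow.Tensor_Tensor.
after = torch.ops.aten.pow.Tensor_Tensor(x, torch.tensor(2.0))

assert torch.allclose(before, after)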

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,8 @@ def is_node_supported(
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.squeeze_copy.dims,
+            exir_ops.edge.aten.pow.Tensor_Scalar,
+            exir_ops.edge.aten.pow.Tensor_Tensor,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
     op_minimum,
     op_mul,
     op_permute,
+    op_pow,
     op_reciprocal,
     op_repeat,
     op_rescale,

backends/arm/operators/op_pow.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import List
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class PowVisitor_080_MI(NodeVisitor):
+    target = "aten.pow.Tensor_Tensor"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
+            raise ValueError(
+                "All inputs and outputs need same dtype. "
+                f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
+            )
+        if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
+            raise ValueError(
+                f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
+            )
+
+        tosa_graph.addOperator(
+            TosaOp.Op().POW,
+            [
+                inputs[0].name,
+                inputs[1].name,
+            ],
+            [output.name],
+            None,
+        )
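The visitor only accepts matching FP16/FP32 dtypes, reflecting that TOSA's POW is a floating-point operator. A module satisfying those constraints might look like the sketch below (the export call only produces the graph; the Arm partitioner and compile spec perform the actual TOSA lowering):

import torch

class Pow(torch.nn.Module):
    # Both operands are tensors of one shared floating-point dtype, matching
    # the checks in PowVisitor_080_MI.define_node.
    def forward(self, base: torch.Tensor, exp: torch.Tensor) -> torch.Tensor:
        return torch.pow(base, exp)

example_inputs = (torch.rand(1, 4, 4), torch.full((1, 4, 4), 2.0))
exported = torch.export.export(Pow(), example_inputs)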

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def _match_pattern(
138138
torch.ops.aten.hardsigmoid.default,
139139
torch.ops.aten.hardswish.default,
140140
torch.ops.aten.full_like.default,
141+
torch.ops.aten.pow.Tensor_Scalar,
141142
]
142143

143144
_one_to_one_shared_input_qspec = [
