Commit a5e326a

Authored by martinlsm (Martin Lindström)
Add support for torch.pow in the Arm backend (#9309)
Co-authored-by: Martin Lindström <[email protected]>
1 parent 3e62f9e commit a5e326a

File tree

11 files changed: +335 -25 lines

backends/arm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
+from .replace_scalar_with_tensor_pass import (  # noqa
+    ReplaceScalarWithTensorArgPassTOSABI,
+    ReplaceScalarWithTensorArgPassTOSAMI,
+)
 from .scalars_to_attribute_pass import ScalarsToAttributePass  # noqa
 from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 7 deletions
@@ -42,18 +42,17 @@
     MatchArgRanksPass,
     QuantizeOperatorArguments,
     RemoveClonePass,
+    ReplaceScalarWithTensorArgPassTOSABI,
+    ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
     ScalarsToAttributePass,
     SizeAdjustConv2DPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
+
 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-
-from executorch.backends.transforms.replace_scalar_with_tensor import (
-    ReplaceScalarWithTensorArgPass,
-)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager

@@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]

@@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         return self._transform(exported_program.graph_module)

     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())

@@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
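
Note how the replacement pass is now split per TOSA profile: the BI (quantized) pipelines get ReplaceScalarWithTensorArgPassTOSABI, which leaves pow.Tensor_Scalar alone so its constant exponent can later be folded into a lookup table, while the MI pipeline gets the TOSAMI variant, which rewrites pow as well. A minimal sketch of the rewrite idea in plain PyTorch (illustration only, not the ExecuTorch pass itself):

import torch

def scalar_to_tensor_pow(x: torch.Tensor, exponent: float) -> torch.Tensor:
    # pow.Tensor_Scalar -> pow.Tensor_Tensor: materialize the scalar as a tensor
    exp_tensor = torch.tensor(exponent, dtype=x.dtype)
    return torch.pow(x, exp_tensor)

x = torch.rand(2, 3)
assert torch.allclose(scalar_to_tensor_pow(x, 2.0), torch.pow(x, 2.0))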

backends/arm/_passes/insert_table_ops.py

Lines changed: 54 additions & 10 deletions
@@ -5,7 +5,8 @@

 # pyre-unsafe

-from typing import Callable, Dict
+from itertools import chain
+from typing import Callable, cast, Dict, Iterator, Set

 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node

@@ -17,7 +18,7 @@

 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
-
+from torch.fx.node import Node
 from torch.library import impl, Library

 lib = Library("tosa", "DEF")

@@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs):  # pyre-ignore
     return args[0].to(dtype=torch.int32)


-class InsertTableOpsPass(ExportPass):
+class TableOps:
     """
-    For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
-    edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
-    When lowering the _table node target_str will be used to find the corresponding torch operator
-    which will be used to produce the table values in operators/op_table.py.
+    Helper class for finding the corresponding table operator for a given Node.
     """

-    table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
+    # Targets that follow a straightforward one-to-one mapping to their table op
+    unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
         exir_ops.edge.aten.ceil.default: torch.ceil,
         exir_ops.edge.aten.exp.default: torch.exp,
         exir_ops.edge.aten.floor.default: torch.floor,

@@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass):
         exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
     }

+    # Targets that must be treated explicitly
+    special_table_ops: Set[EdgeOpOverload] = {
+        exir_ops.edge.aten.pow.Tensor_Scalar,
+    }
+
+    def __init__(self, exported_program: ExportedProgram):
+        self.exported_program = exported_program
+
+    def __contains__(self, node: Node) -> bool:
+        return (
+            node.target in self.unary_table_ops or node.target in self.special_table_ops
+        )
+
+    def __getitem__(self, node: Node):
+        target = cast(EdgeOpOverload, node.target)
+        if target in self.unary_table_ops:
+            return self.unary_table_ops[target]
+        elif target in self.special_table_ops:
+            match target:
+                case exir_ops.edge.aten.pow.Tensor_Scalar:
+                    # Exponent is a constant. Embed it into a lambda.
+                    exp = cast(int, node.args[1])
+                    return lambda x: torch.pow(x, exp).flatten()
+                case _:
+                    # Every op in special_table_ops must be handled above
+                    raise AssertionError("Unhandled table operation")
+        else:
+            raise KeyError(f"Table op for {target} does not exist")
+
+    @staticmethod
+    def included_ops() -> Iterator[EdgeOpOverload]:
+        return chain(TableOps.unary_table_ops, TableOps.special_table_ops)
+
+
+class InsertTableOpsPass(ExportPass):
+    """
+    Ops in self.table_ops need to be serialized as a TOSA TABLE. This pass replaces
+    these edge ops with a tosa._table(input: Tensor, target_str: str), where
+    target_str == str(node.target). When lowering the _table node, target_str is used
+    to find the corresponding torch operator that produces the table values in operators/op_table.py.
+    """
+
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.exported_program = exported_program
+        self.table_ops = TableOps(exported_program)

@@ -166,7 +208,7 @@ def generate_table_values(
     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
         for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in self.table_ops:
+            if node.op != "call_function" or node not in self.table_ops:
                 continue
             input_qparams = node.meta["input_qparams"]
             output_qparams = node.meta["output_qparams"]

@@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult:

             # Generate table buffer and how much to lshift the table output.
             buffer, lshift = self.generate_table_values(
-                torch_op=self.table_ops[node.target],
+                torch_op=self.table_ops[node],
                 in_quantargs=input_qparams[0],
                 out_quantargs=output_qparams[0],
             )

@@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
             output_node = rescale_node

             node.replace_all_uses_with(output_node)
+
             graph_module.graph.erase_node(node)
+
             output_node.meta["input_qparams"] = input_qparams
             output_node.meta["output_qparams"] = output_qparams
             modified = True
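
The lambda captured in TableOps.__getitem__ lets a constant-exponent pow be lowered like any other unary table op. A hedged sketch of how such an int8 table could be built (the real logic lives in generate_table_values; the scales below are made-up examples):

import torch

def build_pow_table(exp: float, in_scale: float, out_scale: float) -> torch.Tensor:
    qmin, qmax = -128, 127
    q_in = torch.arange(qmin, qmax + 1, dtype=torch.float32)
    fp_in = q_in * in_scale                   # dequantize every possible int8 input
    fp_out = torch.pow(fp_in, exp).flatten()  # body of the captured lambda
    q_out = torch.round(fp_out / out_scale).clamp(qmin, qmax)
    return q_out.to(torch.int8)

table = build_pow_table(exp=2.0, in_scale=0.05, out_scale=0.1)
assert table.shape == (256,)  # one entry per representable int8 input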

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def __init__(self, exported_program):
             exir_ops.edge.aten.bitwise_right_shift.Tensor,
             exir_ops.edge.aten.bitwise_left_shift.Tensor,
             exir_ops.edge.aten.eq.Tensor,
+            exir_ops.edge.aten.pow.Tensor_Tensor,
         ]

     def _match_op_rank(self, graph_module, node, arg, max_rank):
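
pow.Tensor_Tensor joins the ops whose inputs must share a rank, since TOSA broadcasting expects equal-rank operands. The pass inserts reshape nodes; the tensor-level effect is roughly:

import torch

x = torch.rand(2, 3, 4)             # rank 3
exp = torch.tensor([2.0])           # rank 1
exp_matched = exp.reshape(1, 1, 1)  # pad with leading singleton dims to rank 3
assert torch.equal(torch.pow(x, exp), torch.pow(x, exp_matched))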
backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+from typing import Dict
+
+import torch
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+
+
+# Operators that are included for both TOSA profiles
+_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
+    exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
+    exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
+    exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
+    exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
+    exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
+    exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
+    exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
+    torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
+    torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
+    torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
+    torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
+    torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
+    torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
+    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
+}
+
+
+class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass):
+    scalar_to_tensor_ops = _common_ops | {
+        exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
+        torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
+    }
+
+    def __init__(self):
+        super().__init__(self.scalar_to_tensor_ops)
+
+
+class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass):
+    scalar_to_tensor_ops = _common_ops
+
+    def __init__(self):
+        super().__init__(self.scalar_to_tensor_ops)
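
A quick sanity check of the intended split, assuming an ExecuTorch build containing this commit: pow appears only in the MI mapping, because BI handles the constant-exponent case through a table instead.

from executorch.backends.arm._passes import (
    ReplaceScalarWithTensorArgPassTOSABI,
    ReplaceScalarWithTensorArgPassTOSAMI,
)
from executorch.exir.dialects._ops import ops as exir_ops

pow_scalar = exir_ops.edge.aten.pow.Tensor_Scalar
assert pow_scalar in ReplaceScalarWithTensorArgPassTOSAMI.scalar_to_tensor_ops
assert pow_scalar not in ReplaceScalarWithTensorArgPassTOSABI.scalar_to_tensor_ops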

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -198,6 +198,8 @@ def is_node_supported(
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.squeeze_copy.dims,
+            exir_ops.edge.aten.pow.Tensor_Scalar,
+            exir_ops.edge.aten.pow.Tensor_Tensor,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
     op_minimum,
     op_mul,
     op_permute,
+    op_pow,
     op_reciprocal,
     op_repeat,
     op_rescale,

backends/arm/operators/op_pow.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import List
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class PowVisitor_080_MI(NodeVisitor):
+    target = "aten.pow.Tensor_Tensor"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
+            raise ValueError(
+                "All inputs and outputs need same dtype. "
+                f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
+            )
+        if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
+            raise ValueError(
+                f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
+            )
+
+        tosa_graph.addOperator(
+            TosaOp.Op().POW,
+            [
+                inputs[0].name,
+                inputs[1].name,
+            ],
+            [output.name],
+            None,
+        )
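
To see both newly supported overloads in an exported graph, a small model can mix them; this is a plain torch.export sketch, not the full Arm lowering flow:

import torch

class PowModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # pow.Tensor_Tensor (TOSA POW under MI) + pow.Tensor_Scalar (table under BI)
        return torch.pow(x, y) + torch.pow(x, 2.0)

ep = torch.export.export(PowModel(), (torch.rand(1, 4), torch.rand(1, 4)))
print(ep.graph)  # contains aten.pow.Tensor_Tensor and aten.pow.Tensor_Scalar calls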

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ def _match_pattern(
     torch.ops.aten.hardswish.default,
     torch.ops.aten.hardswish_.default,
     torch.ops.aten.full_like.default,
+    torch.ops.aten.pow.Tensor_Scalar,
 ]

 _one_to_one_shared_input_qspec = [
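
pow.Tensor_Scalar fits this one-to-one annotation list because, with the exponent fixed, it behaves as a unary op: input and output each get independent quantization parameters. A toy illustration of why the two ranges differ (values are arbitrary):

import torch

x = torch.linspace(-2.0, 2.0, steps=101)
y = torch.pow(x, 2.0)
# Input spans [-2, 2] while output spans [0, 4], so separate scales are needed.
in_scale = ((x.max() - x.min()) / 255).item()
out_scale = ((y.max() - y.min()) / 255).item()
print(f"in_scale={in_scale:.4f}, out_scale={out_scale:.4f}")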
