Skip to content

Commit 9f56b40

Browse files
Martin Lindström
authored and committed
Arm backend: Add POW operator
Implement support for torch.pow in the MI and BI profiles of TOSA. For MI, the operator works as PyTorch's reference implementation, except that the base operand cannot be a scalar but must be a tensor. This is due to a general bug. For BI, the exponent operand must be a scalar and a constant value. The base operand must be a tensor. Change-Id: I9c91b2a19ef43ae2ef884640974017824327dbf3
1 parent 8179aa3 commit 9f56b40

File tree

8 files changed

+322
-16
lines changed

8 files changed

+322
-16
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@
66

77
# pyre-unsafe
88

9-
from typing import Callable, Dict
9+
from typing import Callable, cast, Dict, Set
1010

1111
import torch
1212
from executorch.backends.arm._passes.arm_pass_utils import create_node
1313
from executorch.backends.arm.tosa_quant_utils import QuantArgs
14+
from executorch.backends.transforms.utils import delete_constant_placeholder
1415
from executorch.exir import ExportedProgram
1516

1617
from executorch.exir.dialects._ops import ops as exir_ops
1718
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1819

1920
from executorch.exir.pass_base import ExportPass, PassResult
2021
from torch.fx import GraphModule
22+
from torch.fx.node import Node
2123
from torch.library import impl, Library
2224

2325
lib = Library("tosa", "DEF")
@@ -29,6 +31,59 @@ def _table_impl(*args, **kwargs): # pyre-ignore
2931
return args[0]
3032

3133

34+
class TableOps:
35+
"""
36+
Helper class for finding the corresponding table operator for a given Node.
37+
"""
38+
39+
def __init__(self, exported_program: ExportedProgram):
40+
self.exported_program = exported_program
41+
42+
# Targets that follow a straigtforward one-to-one mapping to their table op
43+
self.unary_table_ops: Dict[
44+
EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]
45+
] = {
46+
exir_ops.edge.aten.exp.default: torch.exp,
47+
exir_ops.edge.aten.floor.default: torch.floor,
48+
exir_ops.edge.aten.log.default: torch.log,
49+
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
50+
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
51+
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
52+
exir_ops.edge.aten.tanh.default: torch.tanh,
53+
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
54+
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
55+
}
56+
57+
# Targets that must be treated explicitly
58+
self.special_table_ops: Set[EdgeOpOverload] = {
59+
exir_ops.edge.aten.pow.Tensor_Tensor,
60+
}
61+
62+
def __contains__(self, node: Node) -> bool:
63+
return (
64+
node.target in self.unary_table_ops or node.target in self.special_table_ops
65+
)
66+
67+
def __getitem__(self, node: Node):
68+
target = cast(EdgeOpOverload, node.target)
69+
if target in self.unary_table_ops:
70+
return self.unary_table_ops[target]
71+
elif target in self.special_table_ops:
72+
match target:
73+
case exir_ops.edge.aten.pow.Tensor_Tensor:
74+
# Exponent is a constant. Retrieve it from the graph and embed it into a lambda.
75+
exp_node = cast(Node, node.args[1])
76+
exp_name = self.exported_program.graph_signature.inputs_to_buffers[
77+
exp_node.name
78+
]
79+
exp = self.exported_program.state_dict[exp_name]
80+
return lambda x: torch.pow(x, exp).flatten()
81+
case _:
82+
raise NotImplementedError("Unhandled table operation")
83+
else:
84+
raise KeyError("Table op for {target} does not exist")
85+
86+
3287
class InsertTableOpsPass(ExportPass):
3388
"""
3489
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
@@ -37,21 +92,10 @@ class InsertTableOpsPass(ExportPass):
3792
which will be used to produce the table values in operators/op_table.py.
3893
"""
3994

40-
table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
41-
exir_ops.edge.aten.exp.default: torch.exp,
42-
exir_ops.edge.aten.floor.default: torch.floor,
43-
exir_ops.edge.aten.log.default: torch.log,
44-
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
45-
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
46-
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
47-
exir_ops.edge.aten.tanh.default: torch.tanh,
48-
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
49-
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
50-
}
51-
5295
def __init__(self, exported_program: ExportedProgram) -> None:
5396
super().__init__()
5497
self.exported_program = exported_program
98+
self.table_ops = TableOps(exported_program)
5599

56100
def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
57101
"""
@@ -86,7 +130,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
86130
def call(self, graph_module: GraphModule) -> PassResult:
87131
modified = False
88132
for node in graph_module.graph.nodes:
89-
if node.op != "call_function" or node.target not in self.table_ops:
133+
if node.op != "call_function" or node not in self.table_ops:
90134
continue
91135
input_qparams = node.meta["input_qparams"]
92136
output_qparams = node.meta["output_qparams"]
@@ -104,7 +148,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
104148
assert len(output_qparams) == 1
105149
# Generate table buffer
106150
buffer = self.generate_table_values(
107-
torch_op=self.table_ops[node.target],
151+
torch_op=self.table_ops[node],
108152
in_quantargs=input_qparams[0],
109153
out_quantargs=output_qparams[0],
110154
)
@@ -115,7 +159,19 @@ def call(self, graph_module: GraphModule) -> PassResult:
115159
buffer_name=table_node.name.replace("_default", ""), buffer=buffer
116160
)
117161
node.replace_all_uses_with(table_node)
118-
graph_module.graph.erase_node(node)
162+
163+
if node.target in self.table_ops.special_table_ops:
164+
# The node must be treated explicitly
165+
match node.target:
166+
case exir_ops.edge.aten.pow.Tensor_Tensor:
167+
exp_node = node.args[1]
168+
graph_module.graph.erase_node(node)
169+
delete_constant_placeholder(self.exported_program, exp_node)
170+
case _:
171+
raise NotImplementedError("Unhandled table operation")
172+
else:
173+
graph_module.graph.erase_node(node)
174+
119175
table_node.meta["input_qparams"] = input_qparams
120176
table_node.meta["output_qparams"] = output_qparams
121177
modified = True

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(self, exported_program):
4545
exir_ops.edge.aten.sub.Tensor,
4646
exir_ops.edge.aten.mul.Tensor,
4747
exir_ops.edge.aten.div.Tensor,
48+
exir_ops.edge.aten.pow.Tensor_Tensor,
4849
]
4950

5051
def _match_op_rank(self, graph_module, node, arg, max_rank):

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ def is_node_supported(
194194
exir_ops.edge.aten.clone.default,
195195
exir_ops.edge.aten.unsqueeze_copy.default,
196196
exir_ops.edge.aten.squeeze_copy.dims,
197+
exir_ops.edge.aten.pow.Tensor_Scalar,
198+
exir_ops.edge.aten.pow.Tensor_Tensor,
197199
operator.getitem,
198200
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
199201
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
op_minimum,
3232
op_mul,
3333
op_permute,
34+
op_pow,
3435
op_reciprocal,
3536
op_repeat,
3637
op_rescale,

backends/arm/operators/op_pow.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import serializer.tosa_serializer as ts
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
17+
from serializer.tosa_serializer import TosaOp
18+
from torch.fx import Node
19+
20+
21+
@register_node_visitor
class PowVisitor_080_MI(NodeVisitor):
    """Serialize aten.pow.Tensor_Tensor as a TOSA POW operator (TOSA-0.80+MI).

    Both operands must be tensors of the same floating-point dtype
    (FP32 or FP16), matching the output dtype.
    """

    target = "aten.pow.Tensor_Tensor"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add a POW operator to *tosa_graph* for *node*.

        Raises:
            ValueError: if dtypes of inputs/output disagree, or the dtype
                is not FP32/FP16.
        """
        if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
            raise ValueError(
                # Fix: original implicit concatenation produced "...dtype.Got ..."
                # with no separating space.
                "All inputs and outputs need same dtype. "
                f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
            )
        if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
            raise ValueError(
                f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
            )

        tosa_graph.addOperator(
            TosaOp.Op().POW,
            [
                inputs[0].name,
                inputs[1].name,
            ],
            [output.name],
            None,
        )

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def _match_pattern(
137137
torch.ops.aten.hardsigmoid.default,
138138
torch.ops.aten.hardswish.default,
139139
torch.ops.aten.full_like.default,
140+
torch.ops.aten.pow.Tensor_Tensor,
140141
]
141142

142143
_one_to_one_shared_input_qspec = [

0 commit comments

Comments
 (0)