Skip to content

Commit 8d11ea6

Browse files
authored
Arm backend: Convert int pow to multiplications (#11037)
### Summary Add a pass to convert integer pow to series of multiplications to handle square operations on negative values since TOSA 1.0 only allows values > 0 for its POW operation. ### Test plan Test on internal and external CI. Signed-off-by: Per Åstrand <[email protected]>
1 parent d8c26ee commit 8d11ea6

File tree

5 files changed

+132
-1
lines changed

5 files changed

+132
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
1616
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa
1717
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
18+
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
1819
from .convert_minmax_pass import ConvertMinMaxPass # noqa
1920
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2021
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ConvertAnyDefaultDimDimsPass,
1919
ConvertExpandCopyToRepeatPass,
2020
ConvertFullLikeToFullPass,
21+
ConvertIntPowToMuls,
2122
ConvertMinMaxPass,
2223
ConvertMmToBmmPass,
2324
ConvertSplitToSlicePass,
@@ -131,14 +132,14 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
131132

132133
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
133134
self.add_pass(DecomposeSqrtPass())
135+
self.add_pass(ConvertIntPowToMuls())
134136
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
135137
self.add_pass(FuseQuantizedActivationPass())
136138
self.add_pass(RemoveGetItemPass())
137139
self.add_pass(ConvertSplitToSlicePass())
138140
self.add_pass(FuseBatchnorm2DPass(exported_program))
139141
self.add_pass(ConvertMmToBmmPass())
140142
self.add_pass(DecomposeLinearPass())
141-
self.add_pass(DecomposeLinearVectorNormPass())
142143
self.add_pass(DecomposeLeakyReLUPass())
143144
self.add_pass(DecomposeBatchNormPass())
144145
self.add_pass(DecomposeLayerNormPass())
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
11+
12+
class ConvertIntPowToMuls(ArmPass):
13+
"""
14+
Replaces pow with integer exponent with a series of multiplications.
15+
Only handles pow.Tensor_Scalar and not pow.Tensor_Tensor.
16+
Needs to be run before doing scalar to tensor conversion.
17+
"""
18+
19+
def call_operator(self, op, args, kwargs, meta):
20+
if op != exir_ops.edge.aten.pow.Tensor_Scalar:
21+
return super().call_operator(op, args, kwargs, meta)
22+
23+
x = args[0]
24+
exp = args[1]
25+
26+
# Handle zero first and return early
27+
if exp == 0:
28+
# return a tensor of ones with the same shape as x
29+
return super().call_operator(
30+
exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True
31+
)
32+
33+
if not isinstance(exp, int):
34+
return super().call_operator(op, args, kwargs, meta)
35+
36+
# Handle negative exponent
37+
if exp < 0:
38+
x = super().call_operator(
39+
exir_ops.edge.aten.reciprocal.default, (x,), {}, meta, True
40+
)
41+
exp = -exp
42+
43+
res = x
44+
45+
# Consider exponentiation by squaring, if exp turns out to be large.
46+
# Now we just roll out the multiplications.
47+
for _ in range(exp - 1):
48+
res = super().call_operator(
49+
exir_ops.edge.aten.mul.Tensor, (res, x), {}, meta, True
50+
)
51+
52+
return res

backends/arm/test/ops/test_pow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ class Pow_TensorScalar(torch.nn.Module):
7171
torch.abs(torch.randn((1, 2, 3, 6))),
7272
6.789,
7373
),
74+
"neg_base_exp_pos_integer": lambda: (
75+
-torch.abs(torch.randn((1, 2, 3, 6))) - 10,
76+
3,
77+
),
7478
}
7579

7680
def __init__(self, exp):
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes import ConvertIntPowToMuls
10+
11+
from executorch.backends.arm.test import common
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.nn.Module, int] # Input x
16+
17+
18+
class Square(torch.nn.Module):
19+
"""
20+
Basic squaring
21+
"""
22+
23+
def forward(self, x):
24+
return x.square()
25+
26+
def get_inputs(self) -> input_t:
27+
return (torch.rand(4, 4),)
28+
29+
30+
class Pow(torch.nn.Module):
31+
"""
32+
Basic squaring
33+
"""
34+
35+
def __init__(self, exponent):
36+
super().__init__()
37+
self.exponent = exponent
38+
39+
def forward(self, x):
40+
return x.pow(self.exponent)
41+
42+
def get_inputs(self) -> input_t:
43+
return (torch.rand(4, 4),)
44+
45+
46+
test_data = {
47+
"square": (Square(), 1),
48+
"pow_2": (Pow(2), 1),
49+
"pow_3": (Pow(3), 2),
50+
"pow_0": (Pow(0), 0),
51+
"pow_neg_2": (Pow(-2), 1),
52+
}
53+
54+
55+
@common.parametrize("data", test_data)
56+
def test_convert_pow_to_muls(data):
57+
module = data[0]
58+
nbr_muls = data[1]
59+
pipeline = PassPipeline[input_t](
60+
module,
61+
module.get_inputs(),
62+
quantize=False,
63+
ops_before_pass={
64+
"executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 1,
65+
},
66+
ops_not_before_pass=[],
67+
ops_after_pass={
68+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls,
69+
},
70+
ops_not_after_pass=["executorch_exir_dialects_edge__ops_pow_Tensor_Scalar"],
71+
pass_list=[ConvertIntPowToMuls],
72+
)
73+
pipeline.run()

0 commit comments

Comments
 (0)