Commit 14ff52f

Arm: Move ReplaceScalarWithTensorArgPass to transforms (#8519)
The pass is general and can be used by multiple backends. Use it in the Arm backend and make small adjustments to make it work. Signed-off-by: Erik Lundell <[email protected]>
1 parent a7b9697 commit 14ff52f
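
In graph terms, the pass rewrites a .Scalar binary-op overload into its .Tensor counterpart by materializing the scalar operand as a one-element tensor. A minimal before/after sketch of the rewrite in eager PyTorch (illustrative only; the pass itself operates on exported graphs):

import torch

x = torch.ones(3)

# Before: the .Scalar overload carries the scalar as a plain Python number.
y_scalar = torch.ops.aten.add.Scalar(x, 2.0)

# After: the scalar is materialized as a one-element tensor of x's dtype
# (via aten.full in the pass) and the .Tensor overload is used instead.
scalar_t = torch.full((1,), 2.0, dtype=x.dtype)
y_tensor = torch.ops.aten.add.Tensor(x, scalar_t)

assert torch.equal(y_scalar, y_tensor)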

File tree: 7 files changed, +123 -65 lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 1 deletion

@@ -77,6 +77,9 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification

+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -102,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(ConvertFullLikeToFullPass())

+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
@@ -125,7 +129,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         return self._transform(exported_program.graph_module)

     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -176,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
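
For orientation, each add_pass call appends one rewrite to an ordered pipeline that the pass manager later runs over the FX graph; insertion order is execution order, which is why ReplaceScalarWithTensorArgPass is added ahead of the quantization-related passes in the BI pipeline. A simplified stand-in for that pattern (not the actual ArmPassManager):

from typing import Callable, List

import torch.fx as fx


class TinyPassManager:
    # Simplified stand-in for an add_pass/transform pipeline.
    def __init__(self) -> None:
        self._passes: List[Callable[[fx.GraphModule], fx.GraphModule]] = []

    def add_pass(self, p: Callable[[fx.GraphModule], fx.GraphModule]) -> None:
        # Passes run in the order they were added.
        self._passes.append(p)

    def transform(self, gm: fx.GraphModule) -> fx.GraphModule:
        for p in self._passes:
            gm = p(gm)
        return gm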

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions

@@ -113,6 +113,10 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.le.Tensor,
             exir_ops.edge.aten.lt.Tensor,
             exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.add.Scalar,
+            exir_ops.edge.aten.sub.Scalar,
+            exir_ops.edge.aten.mul.Scalar,
+            exir_ops.edge.aten.div.Scalar,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
             exir_ops.edge.aten.sigmoid.default,
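
The .Scalar overloads can be reported as supported here because the pass manager rewrites them to .Tensor variants before TOSA serialization. A trimmed-down sketch of the membership check such a list feeds (hypothetical names; the real checker lives in tosa_supported_operators.py):

import torch.fx as fx

from executorch.exir.dialects._ops import ops as exir_ops

# Hypothetical, trimmed-down support set:
SUPPORTED_TARGETS = {
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.add.Scalar,  # accepted; rewritten to add.Tensor later
}


def is_node_supported(node: fx.Node) -> bool:
    # Only call_function nodes carry an operator target to check.
    return node.op == "call_function" and node.target in SUPPORTED_TARGETS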

backends/arm/test/models/test_conformer.py

Lines changed: 1 addition & 2 deletions

@@ -32,15 +32,14 @@ class TestConformer(unittest.TestCase):
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
         "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
         "executorch_exir_dialects_edge__ops_aten_where_self": 4,
         "executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
         "executorch_exir_dialects_edge__ops_aten_any_dim": 2,
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,
         "torch.ops.aten.scalar_tensor.default": 2,
-        "torch.ops.higher_order.executorch_call_delegate": 5,
+        "torch.ops.higher_order.executorch_call_delegate": 4,
     }

     dim = 16

backends/arm/test/ops/test_scalars.py

Lines changed: 27 additions & 2 deletions

@@ -1,3 +1,8 @@
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import unittest

 import torch
@@ -50,6 +55,22 @@ class Mul(torch.nn.Module):
         def forward(self, x, y):
             return x * y

+    class MulScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.mul.Scalar(x, y)
+
+    class DivScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.div.Scalar(x, y)
+
+    class AddScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.add.Scalar(x, y)
+
+    class SubScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.sub.Scalar(x, y)
+
     class AddInplace(torch.nn.Module):
         def forward(self, x, y):
             x += y
@@ -91,6 +112,10 @@ def forward(self, x):
         ("Sub_", SubInplace()),
         ("Mul_", MulInplace()),
         ("Div_", DivInplace()),
+        ("MulScalar", MulScalar()),
+        ("DivScalar", DivScalar()),
+        ("AddScalar", AddScalar()),
+        ("SubScalar", SubScalar()),
     ]

     const_ops = [("Add", AddConst())]
@@ -108,8 +133,8 @@ def forward(self, x):
             scalar = dtype[1]
             tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))

-            # Don't add (scalar, tensor) test case for inplace ops.
-            if op[0][-1] == "_":
+            # Don't add (scalar, tensor) test case for inplace and .Scalar ops.
+            if op[0][-1] == "_" or op[0][-6:] == "Scalar":
                 continue

             # sub(scalar, tensor) does not work in any case.
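
Outside the test harness, the new modules simply route a Python number to the aten .Scalar overloads; a quick standalone check (illustrative, not part of the commit):

import torch


class MulScalar(torch.nn.Module):
    def forward(self, x, y):
        # y is a plain Python number, dispatched to aten.mul.Scalar.
        return torch.ops.aten.mul.Scalar(x, y)


x = torch.ones(4)
print(MulScalar()(x, 3.0))  # tensor([3., 3., 3., 3.])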

backends/arm/tosa_mapping.py

Lines changed: 3 additions & 1 deletion

@@ -11,6 +11,8 @@
 # the standardised TOSA representation.
 #

+from typing import Sequence
+
 import serializer.tosa_serializer as ts  # type: ignore
 import torch
@@ -99,7 +101,7 @@ def __init__(self, argument) -> None:
         if isinstance(argument, torch.fx.Node):
             self.__process_node(argument)
             return
-        if isinstance(argument, list):
+        if isinstance(argument, Sequence):
             self.__process_list(argument)
             return
         if isinstance(argument, (int, float)):
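
The widened check matters because FX node arguments often arrive as tuples rather than lists; typing.Sequence (an alias of collections.abc.Sequence) matches both. A small illustration:

from typing import Sequence

args = (1, 2)  # fx node arguments are frequently tuples

print(isinstance(args, list))      # False: the old check fell through
print(isinstance(args, Sequence))  # True: the new check also accepts tuples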

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 59 deletions

@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -37,6 +38,9 @@
 )
 from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
@@ -1713,65 +1717,9 @@ def call_operator(self, op, args, kwargs, meta):
     )


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceScalarWithTensorArgPass(ExportPass):
-    """
-    For binary ops like add.Scalar, sub.Scalar, mul.Scalar, and div.Scalar,
-    replace the scalar arg with Tensor arg.
-    """
-
-    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
-        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
-        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
-        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
-        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
-    }
-
-    def get_replacement(self, op, args, kwargs, meta):
-        return super().call_operator(
-            # Replace with .Tensor variant.
-            op=self.scalar_to_tensor_ops[op],
-            args=(
-                # Tensor arg.
-                args[0],
-                # Scalar arg - replace with aten.full tensor.
-                super().call_operator(
-                    exir_ops.edge.aten.full.default,
-                    args=(
-                        (1,),
-                        args[1],
-                    ),
-                    kwargs={"dtype": args[0].to_tensor().dtype},
-                    meta=meta,
-                ),
-                # Other args.
-                *args[2:],
-            ),
-            kwargs=kwargs,
-            meta=meta,
-        )
-
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in self.scalar_to_tensor_ops:
-            return super().call_operator(op, args, kwargs, meta)
-
-        # There must be exactly 2 args (3 for add and sub containing alpha)
-        assert len(args) == 2 or len(args) == 3
-
-        # If there are two args, just replace the op.
-        if len(args) == 2:
-            return self.get_replacement(op, args, kwargs, meta)
-
-        # In case the op has three args, it must be a scalar add/sub op.
-        if (
-            op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
-            or "alpha" in kwargs
-        ):
-            return super().call_operator(op, args, kwargs, meta)
-
-        return self.get_replacement(op, args, kwargs, meta)
-
-
+register_cadence_pass(CadencePassAttribute(opt_level=0))(
+    ReplaceScalarWithTensorArgPass
+)
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceScalarTensorWithFullPass(ExportPass):
     """
"""
backends/transforms/replace_scalar_with_tensor.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass
+
+
+class ReplaceScalarWithTensorArgPass(ExportPass):
+    """
+    For binary ops like add.Scalar, sub.Scalar, mul.Scalar, and div.Scalar,
+    replace the scalar arg with Tensor arg.
+    """
+
+    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
+        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
+        torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
+        torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
+        torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
+        torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
+    }
+
+    def get_replacement(self, op, args, kwargs, meta):
+        return super().call_operator(
+            # Replace with .Tensor variant.
+            op=self.scalar_to_tensor_ops[op],
+            args=(
+                # Tensor arg.
+                args[0],
+                # Scalar arg - replace with aten.full tensor.
+                super().call_operator(
+                    exir_ops.edge.aten.full.default,
+                    args=(
+                        (1,),
+                        args[1],
+                    ),
+                    kwargs={"dtype": args[0].to_tensor().dtype},
+                    meta=meta,
+                ),
+                # Other args.
+                *args[2:],
+            ),
+            kwargs=kwargs,
+            meta=meta,
+        )
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in self.scalar_to_tensor_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # There must be exactly 2 args (3 for add and sub containing alpha)
+        assert len(args) == 2 or len(args) == 3
+
+        # If there are two args, just replace the op.
+        if len(args) == 2:
+            return self.get_replacement(op, args, kwargs, meta)
+
+        # In case the op has three args, it must be a scalar add/sub op.
+        if (
+            op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
+            or "alpha" in kwargs
+        ):
+            return super().call_operator(op, args, kwargs, meta)
+
+        return self.get_replacement(op, args, kwargs, meta)
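
For completeness, a sketch of driving the shared pass directly: ExportPass instances are callable on a GraphModule and return a PassResult. The export/to_edge plumbing below is assumed context, not part of the commit, and IR validation is relaxed because add.Scalar is not a core ATen op:

import torch
from executorch.backends.transforms.replace_scalar_with_tensor import (
    ReplaceScalarWithTensorArgPass,
)
from executorch.exir import EdgeCompileConfig, to_edge


class AddTwo(torch.nn.Module):
    def forward(self, x):
        # Call the .Scalar overload explicitly so the pass has work to do.
        return torch.ops.aten.add.Scalar(x, 2.0)


ep = torch.export.export(AddTwo(), (torch.ones(3),))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
gm = edge.exported_program().graph_module

result = ReplaceScalarWithTensorArgPass()(gm)
print(result.graph_module.graph)  # full + add.Tensor instead of add.Scalar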
