
Commit 6669e18

Erik-Lundell authored and facebook-github-bot committed

Fix scalar arithmetic and add test cases (#6224)
Summary:
Add UnsqueezeScalarPlaceholders pass to make scalars rank 1.
Add MatchArgRanksPass to guarantee the same rank for all inputs to ops that require it.
Additional fixes to make Scalar tests pass.
Map which cases work and which don't.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: I4ea5e189e26cf7aff391ec153d525b2fb61aa16f

Fix shape issues
Change-Id: I0b8588cd5f8b284c25e806bb83bc788067d5b649

Pull Request resolved: #6224
Reviewed By: mergennachin
Differential Revision: D64427014
Pulled By: digantdesai
fbshipit-source-id: 5295e9ffab1d848b111e0cb01aa0ce9142c20781
1 parent 5f12f28 commit 6669e18

12 files changed, +476 -32 lines
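Background (illustrative, not part of the commit): the failure mode being fixed is that traced scalar arithmetic leaves rank-0 tensors in the graph, which the Arm/TOSA lowering cannot handle. A minimal sketch in plain PyTorch:

import torch

# A Python scalar captured as a tensor has rank 0 (shape == torch.Size([])).
scalar = torch.tensor(2.0)
assert scalar.dim() == 0

# Broadcasting hides the rank mismatch in eager mode ...
x = torch.randn(2, 3)
assert (x + scalar).shape == (2, 3)

# ... but the new passes make ranks explicit: scalars become rank 1,
# as UnsqueezeScalarPlaceholdersPass guarantees for placeholders.
assert scalar.unsqueeze(0).shape == (1,)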

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 6 deletions
@@ -9,6 +9,7 @@
 from typing import cast

 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm.tosa_quant_utils import dq_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -52,12 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule):
         NHWC_Order = (0, 2, 3, 1)
         HWCM_Order = (2, 3, 0, 1)
         for node in graph_module.graph.nodes:
-            if isinstance(
-                node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
-            ):
-                node_data = node.meta["val"][0].data
-            else:
-                node_data = node.meta["val"].data
+            node_data = get_first_fake_tensor(node).data

             if len(node_data.shape) == 4:
                 dim_order = NHWC_Order
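Note: the replaced isinstance branch existed because node.meta["val"] holds a collection of fake tensors for multi-output ops. A hedged sketch of how that arises (the max-pool example and make_fx call are illustrative, not from this commit):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # return_indices=True yields two outputs, so the matching graph node's
    # meta["val"] is a collection of FakeTensors rather than a single one.
    return torch.nn.functional.max_pool2d(x, 2, return_indices=True)

gm = make_fx(f, tracing_mode="fake")(torch.randn(1, 1, 4, 4))
for node in gm.graph.nodes:
    if node.op == "call_function" and isinstance(node.meta.get("val"), (tuple, list)):
        print(node.target, "first fake tensor shape:", node.meta["val"][0].shape)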

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 1 deletion
@@ -22,6 +22,7 @@
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,
 )
+from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
@@ -30,6 +31,9 @@
     ScalarsToAttributePass,
 )
 from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
+from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
+    UnsqueezeScalarPlaceholdersPass,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -45,10 +49,12 @@ def transform_to_backend_pipeline(
     ):
         """Apply passes before transforming program to backend"""
         self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
+        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeDivPass())
         self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -61,6 +67,6 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)

     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
-        self.add_pass(DecomposeDivPass())
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(DecomposeDivPass())
         return self._transform(graph_module)
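Note on the reordering in transform_for_annotation_pipeline: ScalarsToAttributePass must now run before DecomposeDivPass, since the decomposition rewrites div(x, d) into mul(x, reciprocal(d)) and reciprocal is a tensor op, so a scalar divisor must first become a tensor attribute. A minimal equivalence check (plain PyTorch, illustrative):

import torch

x = torch.randn(4)
d = 2.0                      # scalar divisor as it appears in user code
d_tensor = torch.tensor(d)   # conceptually what ScalarsToAttributePass produces
assert torch.allclose(x / d, x * torch.reciprocal(d_tensor))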

backends/arm/_passes/arm_pass_utils.py

Lines changed: 20 additions & 0 deletions
@@ -7,9 +7,11 @@
 from typing import Optional

 import torch
+import torch.fx

 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._ops import OpOverload
+from torch._subclasses.fake_tensor import FakeTensor


 def create_node(
@@ -64,3 +66,21 @@ def insert_q_dq_pair(
     # node's first use
     q.args = (anchor,) + q_params
     return dq
+
+
+def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
+    """
+    Returns a FakeTensor from the meta field of 'node'.
+    If the node contains many fake tensors, return the first one.
+    """
+    if isinstance(
+        node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        fake_tensor = node.meta["val"][0]
+    else:
+        fake_tensor = node.meta["val"]
+
+    assert isinstance(
+        fake_tensor, FakeTensor
+    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
+    return fake_tensor
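A hedged usage sketch of the new helper (the traced function and make_fx call are illustrative; only get_first_fake_tensor comes from this commit):

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x + 1

gm = make_fx(f, tracing_mode="fake")(torch.randn(2, 2))
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Works whether meta["val"] is a single FakeTensor or a list of them.
        print(node.target, get_first_fake_tensor(node).shape)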

backends/arm/_passes/decompose_div_pass.py

Lines changed: 6 additions & 3 deletions
@@ -8,15 +8,18 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

+edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
+aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor)
+

 def get_div_decomposition(op) -> tuple:
     """
     Returns the (reciprocal_op, mul_op) pair, where the ops depend on whether
     the div op is in exir_ops or torch.ops.aten.
     """
-    if op == exir_ops.edge.aten.div.Tensor:
+    if op in edge_div_ops:
         return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
-    if op == torch.ops.aten.div.Tensor:
+    if op in aten_div_ops:
         return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
     raise RuntimeError(f"Can't get div decomposition for op {op}")

@@ -33,7 +36,7 @@ class DecomposeDivPass(ExportPass):
     """

     def call_operator(self, op, args, kwargs, meta):
-        if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
+        if op not in (edge_div_ops + aten_div_ops):
             return super().call_operator(op, args, kwargs, meta)

         reciprocal_op, mul_op = get_div_decomposition(op)
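The grouped op tuples now also cover the in-place overload aten.div_.Tensor. A small numeric check of the decomposition itself (plain PyTorch, illustrative):

import torch

a = torch.randn(4)
b = torch.abs(torch.randn(4)) + 0.1  # keep divisors away from zero

# div(a, b) == mul(a, reciprocal(b)), for both div.Tensor and div_.Tensor:
decomposed = torch.mul(a, torch.reciprocal(b))
assert torch.allclose(a / b, decomposed)

a_inplace = a.clone()
a_inplace.div_(b)  # the in-place overload newly listed in aten_div_ops
assert torch.allclose(a_inplace, decomposed)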
backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+
+class MatchArgRanksPass(ExportPass):
+    """
+    For ops in 'targeted_ops', make sure that the inputs share the same rank.
+    New dimensions are inserted at the beginning of the shape until all
+    inputs share the same rank.
+    """
+
+    def __init__(self, exported_program):
+        super().__init__()
+        self.exported_program = exported_program
+
+    targeted_ops = [
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.div.Tensor,
+    ]
+
+    def _match_op_rank(self, graph_module, node, arg, max_rank):
+        """
+        In graph_module, insert a view between arg and node to make the
+        rank of arg match the other args to node.
+        """
+        shape = get_first_fake_tensor(arg).shape
+        rank = len(shape)
+        new_shape = list([1] * (max_rank - rank) + list(shape))
+        with graph_module.graph.inserting_before(node):
+            view = create_node(
+                graph_module.graph,
+                exir_ops.edge.aten.view_copy.default,
+                args=(arg, new_shape),
+                kwargs={},
+            )
+            node.replace_input_with(arg, view)
+
+    def _match_buffer_rank(self, arg, max_rank):
+        """
+        Change arg's fake tensor meta to match max_rank if:
+        - arg is found in inputs_to_buffers or inputs_to_parameters.
+        """
+        fake_tensor = get_first_fake_tensor(arg)
+        shape = fake_tensor.shape
+        rank = len(shape)
+        new_shape = list([1] * (max_rank - rank) + list(shape))
+
+        buffer_name = None
+        if arg.name in self.exported_program.graph_signature.inputs_to_buffers:
+            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                arg.name
+            ]
+        elif arg.name in self.exported_program.graph_signature.inputs_to_parameters:
+            buffer_name = self.exported_program.graph_signature.inputs_to_parameters[
+                arg.name
+            ]
+        if buffer_name:
+            new_tensor = self.exported_program.state_dict[buffer_name].reshape(
+                new_shape
+            )
+            self.exported_program.state_dict[buffer_name] = new_tensor
+            arg.meta["val"] = fake_tensor.fake_mode.from_tensor(
+                new_tensor, static_shapes=True
+            )
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            node = cast(Node, node)
+
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            # Calculate max rank of all inputs to node
+            max_rank = 1
+            for arg in node.args:
+                if isinstance(arg, Node):
+                    shape = get_first_fake_tensor(arg).shape
+                    max_rank = max(max_rank, len(shape))
+
+            # Adjust output shape of args if needed.
+            for arg in node.args:
+                if not isinstance(arg, Node):
+                    continue
+                shape = get_first_fake_tensor(arg).shape
+                rank = len(shape)
+                if rank == max_rank:
+                    continue
+
+                # If the argument is call_function, match shape by inserting view node.
+                if arg.op == "call_function":
+                    self._match_op_rank(graph_module, node, arg, max_rank)
+                else:
+                    # If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta.
+                    self._match_buffer_rank(arg, max_rank)
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
+
+    def ensures(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+            arg0_rank = node.args[0].meta["val"].dim()
+            arg1_rank = node.args[1].meta["val"].dim()
+            if arg0_rank != arg1_rank:
+                raise ValueError(
+                    "Arguments of arithmetic operators need to have the same rank!"
+                )
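A self-contained check of the rank-matching rule the pass implements (the matched_shape helper is a hypothetical mirror of the pass's new_shape computation):

import torch

def matched_shape(shape, max_rank):
    # Mirrors new_shape in MatchArgRanksPass: prepend singleton dims.
    return [1] * (max_rank - len(shape)) + list(shape)

assert matched_shape((3,), 4) == [1, 1, 1, 3]
assert matched_shape((2, 3), 4) == [1, 1, 2, 3]

# Prepending ones matches what broadcasting does implicitly, so the
# inserted view_copy preserves semantics:
a = torch.randn(2, 2, 3)
b = torch.randn(3)
b_view = b.view(matched_shape(b.shape, a.dim()))
assert torch.allclose(a + b, a + b_view)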

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 6 additions & 2 deletions
@@ -7,7 +7,7 @@
 from typing import cast, Union

 import torch
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
@@ -22,10 +22,14 @@ class ScalarsToAttributePass(ExportPass):

     targeted_ops = [
         torch.ops.aten.add.Tensor,
+        torch.ops.aten.add_.Tensor,
         torch.ops.aten.sub.Tensor,
         torch.ops.aten.sub_.Tensor,
+        torch.ops.aten.rsub.Scalar,
         torch.ops.aten.mul.Tensor,
+        torch.ops.aten.mul_.Tensor,
         torch.ops.aten.div.Tensor,
+        torch.ops.aten.div_.Tensor,
     ]

     def call(self, graph_module: GraphModule) -> PassResult:
@@ -37,7 +41,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             biggest_rank = 1
             for arg in n.args:
                 if isinstance(arg, Node):
-                    _, shape, _ = extract_tensor_meta(arg.meta)
+                    shape = get_first_fake_tensor(arg).shape
                     biggest_rank = max(biggest_rank, len(shape))

             new_args = []
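The expanded targeted_ops list also catches in-place and reversed-subtraction variants. A hedged sketch of how those overloads can appear in a traced graph (make_fx usage is illustrative, and the exact traced targets may vary by PyTorch version):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x * 2.0      # expected: aten.mul.Tensor with a scalar argument
    y += 1.0         # expected: the in-place aten.add_.Tensor
    return 1.0 - y   # expected: aten.rsub.Scalar

gm = make_fx(f)(torch.randn(2))
print(gm.graph)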
backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class UnsqueezeScalarPlaceholdersPass(ExportPass):
+    """
+    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
+    This pass unsqueezes the placeholders to make sure shape is at least (1,).
+    """
+
+    def __init__(self, exported_program):
+        self.exported_program = exported_program
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if node.op != "placeholder":
+                continue
+            rank = node.meta["val"].dim()
+            if rank == 0:
+                if not (
+                    node.name in self.exported_program.graph_signature.inputs_to_buffers
+                    or node.name
+                    in self.exported_program.graph_signature.inputs_to_parameters
+                ):
+                    continue
+                tensor = self.exported_program.state_dict[node.name]
+                if tensor.dim() == 0:
+                    self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
+                    node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
+                        tensor.unsqueeze(0), static_shapes=True
+                    )
+                else:
+                    node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
+                        tensor, static_shapes=True
+                    )
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
+
+    def ensures(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if node.op == "placeholder":
+                rank = node.meta["val"].dim()
+                if rank == 0:
+                    raise ValueError("Placeholders of rank 0 are not supported!")
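A minimal illustration of what the pass does to a rank-0 buffer: both the stored tensor and the node's fake-tensor meta are rebuilt at rank 1 (the standalone FakeTensorMode here is illustrative; the pass reuses the existing fake mode from node.meta["val"]):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

real = torch.tensor(2.0)      # rank-0 buffer as found in the state_dict
real = real.unsqueeze(0)      # stored back with shape (1,)

fake_mode = FakeTensorMode()
fake = fake_mode.from_tensor(real, static_shapes=True)  # new meta["val"]
assert fake.dim() == 1 and tuple(fake.shape) == (1,)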

backends/arm/quantizer/quantization_annotation/mul_annotator.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def _annotate_mul(

     annotated_partitions = []
     for node in gm.graph.nodes:
-        if node.target not in (torch.ops.aten.mul.Tensor,):
+        if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor):
             continue
         mul_node = node
         annotated_partitions.append([mul_node])

backends/arm/quantizer/quantization_annotation/sub_annotator.py

Lines changed: 5 additions & 10 deletions
@@ -6,8 +6,6 @@

 # pyre-unsafe

-import itertools
-import operator
 from typing import Callable, List, Optional

 import torch
@@ -16,7 +14,6 @@
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import GraphModule, Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


 @register_annotator("sub")
@@ -25,14 +22,12 @@ def _annotate_sub(
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    sub_partitions = get_source_partitions(
-        gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
-    )
-    sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
     annotated_partitions = []
-    for sub_partition in sub_partitions:
-        annotated_partitions.append(sub_partition.nodes)
-        sub_node = sub_partition.output_nodes[0]
+    for node in gm.graph.nodes:
+        if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor):
+            continue
+        annotated_partitions.append(node)
+        sub_node = node
         if arm_quantizer_utils.is_annotated(sub_node):
             continue

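Both annotators now match aten targets by iterating graph nodes directly instead of going through get_source_partitions. A hedged sketch of the matching idea (the traced function and make_fx call are illustrative):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return x - y

gm = make_fx(f)(torch.randn(2), torch.randn(2))
sub_targets = (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor)
subs = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target in sub_targets
]
assert len(subs) == 1  # found without source-partition bookkeeping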