Skip to content

Commit dd7fa6a

Browse files
Add ArmPassManager (#3749)
Summary: - Add ArmPassManager that's responsible for pass handling. - Add RemoveClones-pass. - Add initial pass test. Change-Id: I71e1e0787d0788a835a412c608ae75331fe65cc2 Pull Request resolved: #3749 Reviewed By: mergennachin Differential Revision: D59260418 Pulled By: digantdesai fbshipit-source-id: 7c75ebfce125ddd53ceefd4c14dc280eba9098ec
1 parent 4b45264 commit dd7fa6a

File tree

10 files changed

+190
-93
lines changed

10 files changed

+190
-93
lines changed

backends/arm/README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,19 @@ ethos-u-vela compilation stack. which follows the fully AoT flow.
1515
## Layout
1616

1717
Export:
18-
- `arm_backend.py` - Main entrypoint for the ArmPartitioner and ArmBackend. For more information see the section on [Arm Bac
19-
kend Architecture](#arm-backend-architecture). For examples of use see `executorch/examples/arm`.
18+
- `arm_backend.py` - Main entrypoint for the ArmPartitioner and ArmBackend. For more information see the section on
19+
[Arm Backend Architecture](#arm-backend-architecture). For examples of use see `executorch/examples/arm`.
2020
- `tosa_mapping.py` - utilities for mapping edge dialect to TOSA
2121
- `tosa_quant_utils.py` - utilities for mapping quantization information to TOSA encoding
2222

23+
Operators:
24+
- `node_visitor.py` - Base class for edge operator lowering
25+
- `op_*.py` - Edge operator lowering/serialization to TOSA
26+
27+
Passes:
28+
- `arm_pass_manager.py` - Pass manager. Will decide which passes need to be applied depending on the compile_spec.
29+
- `*_pass.py` - Compiler passes derived from ExportPass
30+
2331
Quantization:
2432
- `arm_quantizer.py` - Quantizer for Arm backend
2533
- `arm_quantizer_utils.py` - Utilities for quantization
@@ -36,8 +44,10 @@ This is the structure of the test directory
3644

3745
```
3846
test # Root test folder
47+
├── misc # Testing of debug features
3948
├── models # Full model tests
4049
├── ops # Single op tests
50+
├── passes # Compiler passes tests
4151
├── tester # Arm Tester class
4252
├── tosautil # Utility functions for TOSA artifacts
4353
├ common.py # Common functions and definitions used by many tests

backends/arm/arm_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.backends.arm.arm_vela import vela_compile
1818
from executorch.backends.arm.operators.node_visitor import get_node_visitors
1919
from executorch.backends.arm.operators.op_placeholder import process_placeholder
20+
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
2021
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
2122
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
2223
from executorch.backends.arm.tosa_utils import (
@@ -241,10 +242,13 @@ def preprocess( # noqa: C901
241242
# Converted output for this subgraph, serializer needs path early as it emits
242243
# const data directly. Path created and data written only in debug builds.
243244
tosa_graph = ts.TosaSerializer(artifact_path)
245+
graph_module = ArmPassManager().transform_to_backend_pipeline(
246+
graph_module=edge_program.graph_module, compile_spec=compile_spec
247+
)
244248

245249
node_visitors = get_node_visitors(edge_program)
246250

247-
for node in edge_program.graph.nodes:
251+
for node in graph_module.graph.nodes:
248252
if node.op == "call_function":
249253
# Unpack arguments and convert
250254
inputs = []

backends/arm/arm_partitioner.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.arm.arm_backend import ArmBackend
13+
from executorch.backends.arm.passes.tag_io_quant_pass import TagIOQuantPass
1314
from executorch.exir.backend.compile_spec_schema import CompileSpec
1415
from executorch.exir.backend.partitioner import (
1516
DelegationSpec,
@@ -18,6 +19,7 @@
1819
)
1920
from executorch.exir.backend.utils import tag_constant_data
2021
from executorch.exir.dialects._ops import ops as exir_ops
22+
from executorch.exir.passes import PassManager
2123
from torch.export.exported_program import ExportedProgram
2224
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2325

@@ -54,9 +56,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5456
supported &= self.is_node_supported_custom(node)
5557

5658
# Override partitioning based on pre partition passes
57-
if supported and "arm_partition" in node.meta:
58-
supported = supported & node.meta["arm_partition"]
59-
node.meta.pop("arm_partition")
59+
if "arm_override_partition" in node.meta:
60+
supported = supported & node.meta["arm_override_partition"]
61+
node.meta.pop("arm_override_partition")
6062

6163
return supported
6264

@@ -69,54 +71,6 @@ def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
6971
return True
7072

7173

72-
from executorch.exir.pass_base import ExportPass, PassResult
73-
from executorch.exir.passes import PassManager
74-
75-
76-
class TagIOQuant(ExportPass):
77-
"""
78-
Pass run before partitioning to tag Q/DQ on any placeholder and output
79-
to ensure we don't greedily partition them for device. Float conversion
80-
has to happen outside a TOSA base inference profile.
81-
"""
82-
83-
def __init__(self, edge_program: torch.export.ExportedProgram):
84-
super(TagIOQuant, self).__init__()
85-
self.edge_program = edge_program
86-
87-
def is_quant_node(self, node: torch.fx.node.Node):
88-
return node.target in {
89-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
90-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
91-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
92-
}
93-
94-
def is_dequant_node(self, node: torch.fx.node.Node):
95-
return node.target in {
96-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
97-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
98-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
99-
}
100-
101-
def call(self, graph_module: torch.fx.GraphModule):
102-
for node in graph_module.graph.nodes:
103-
# tag q of input
104-
if node.op == "placeholder":
105-
for user in node.users.keys():
106-
# if we have an input going into a quantize
107-
if self.is_quant_node(user):
108-
user.meta["arm_partition"] = False
109-
110-
# tag dq of outputs
111-
if node.op == "output":
112-
quant, *_ = node.args[0]
113-
if self.is_dequant_node(quant):
114-
quant.meta["arm_partition"] = False
115-
116-
graph_module.recompile()
117-
return PassResult(graph_module, True)
118-
119-
12074
@final
12175
class ArmPartitioner(Partitioner):
12276
def __init__(self, compile_spec: List[CompileSpec]) -> None:
@@ -133,7 +87,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
13387
# Exclude IO quantization from the partition
13488
passes = PassManager(
13589
passes=[
136-
TagIOQuant(exported_program),
90+
TagIOQuantPass(),
13791
]
13892
)
13993
passes(exported_program.graph_module)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -9,7 +9,6 @@
99
op_addmm,
1010
op_avg_pool2d,
1111
op_batch_norm,
12-
op_clone,
1312
op_conv2d,
1413
op_dequant,
1514
op_div,

backends/arm/operators/op_clone.py

Lines changed: 0 additions & 34 deletions
This file was deleted.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
10+
from executorch.exir.backend.compile_spec_schema import CompileSpec
11+
from executorch.exir.pass_manager import PassManager
12+
13+
14+
class ArmPassManager(PassManager):
15+
16+
def _transform(self, graph_module: torch.fx.Graph):
17+
return self(graph_module).graph_module
18+
19+
def transform_to_backend_pipeline(
20+
self, graph_module: torch.fx.Graph, compile_spec: CompileSpec
21+
):
22+
"""Apply passes before transforming program to backend"""
23+
self.add_pass(RemoveClonePass())
24+
25+
return self._transform(graph_module)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class RemoveClonePass(ExportPass):
13+
14+
def call(self, graph_module: torch.fx.GraphModule):
15+
for node in graph_module.graph.nodes:
16+
if node.op != "call_function":
17+
continue
18+
if node.target == exir_ops.edge.aten.clone.default:
19+
for user in list(node.users):
20+
# TODO remove dq/q-ops around removed clone-op
21+
user.replace_input_with(node, node.args[0])
22+
graph_module.graph.erase_node(node)
23+
graph_module.graph.eliminate_dead_code()
24+
graph_module.recompile()
25+
return PassResult(graph_module, True)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class TagIOQuantPass(ExportPass):
13+
"""
14+
Pass run before partitioning to tag Q/DQ on any placeholder and output
15+
to ensure we don't greedily partition them for device. Float conversion
16+
has to happen outside a TOSA base inference profile.
17+
"""
18+
19+
def is_quant_node(self, node: torch.fx.node.Node):
20+
return node.target in {
21+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
22+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
23+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
24+
}
25+
26+
def is_dequant_node(self, node: torch.fx.node.Node):
27+
return node.target in {
28+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
29+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
30+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
31+
}
32+
33+
def call(self, graph_module: torch.fx.GraphModule):
34+
for node in graph_module.graph.nodes:
35+
# tag q of input
36+
if node.op == "placeholder":
37+
for user in node.users.keys():
38+
# if we have an input going into a quantize
39+
if self.is_quant_node(user):
40+
user.meta["arm_override_partition"] = False
41+
42+
# tag dq of outputs
43+
if node.op == "output":
44+
quant, *_ = node.args[0]
45+
if self.is_dequant_node(quant):
46+
quant.meta["arm_override_partition"] = False
47+
48+
graph_module.recompile()
49+
return PassResult(graph_module, True)

backends/arm/test/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def get_tosa_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
102102
return compile_spec
103103

104104

105-
def get_u55_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
105+
def get_u55_compile_spec(
106+
permute_memory_to_nhwc=False, quantize_io=False, custom_path=None
107+
):
106108
"""
107109
Default compile spec for Ethos-U55 tests.
108110
"""
@@ -115,7 +117,7 @@ def get_u55_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
115117
memory_mode="Shared_Sram",
116118
extra_flags=None,
117119
)
118-
.set_quantize_io(is_option_enabled("quantize_io"))
120+
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
119121
.set_permute_memory_format(permute_memory_to_nhwc)
120122
.dump_intermediate_artifacts_to(artifact_path)
121123
.build()
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
14+
15+
class Add(torch.nn.Module):
16+
17+
def get_inputs(self):
18+
return (torch.rand(1, 10, 10, 10),)
19+
20+
def forward(self, x):
21+
return x + x
22+
23+
24+
class TestTagIOQuantPass(unittest.TestCase):
25+
26+
def _tosa_BI_u55_pipeline(self, module: torch.nn.Module):
27+
(
28+
ArmTester(
29+
module,
30+
example_inputs=module.get_inputs(),
31+
compile_spec=common.get_u55_compile_spec(quantize_io=True),
32+
)
33+
.quantize()
34+
.export()
35+
.to_edge()
36+
.check_count(
37+
{
38+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2
39+
}
40+
)
41+
.check_count(
42+
{
43+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2
44+
}
45+
)
46+
.partition()
47+
.check_count(
48+
{
49+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1
50+
}
51+
)
52+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
53+
.check_count(
54+
{
55+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1
56+
}
57+
)
58+
# .to_executorch() requires additional steps
59+
)
60+
61+
def test_BI_u55_artifact(self):
62+
model = Add()
63+
self._tosa_BI_u55_pipeline(model)

0 commit comments

Comments
 (0)