Skip to content

Commit 613e7a2

Browse files
committed
Update base for Update on "Use external_deps for sentencepiece"
as title Differential Revision: [D59770172](https://our.internmc.facebook.com/intern/diff/D59770172/) [ghstack-poisoned]
2 parents db09613 + fbe0af1 commit 613e7a2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+799
-677
lines changed

backends/arm/README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,19 @@ ethos-u-vela compilation stack. which follows the fully AoT flow.
1515
## Layout
1616

1717
Export:
18-
- `arm_backend.py` - Main entrypoint for the ArmPartitioner and ArmBackend. For more information see the section on [Arm Bac
19-
kend Architecture](#arm-backend-architecture). For examples of use see `executorch/examples/arm`.
18+
- `arm_backend.py` - Main entrypoint for the ArmPartitioner and ArmBackend. For more information see the section on
19+
[Arm Backend Architecture](#arm-backend-architecture). For examples of use see `executorch/examples/arm`.
2020
- `tosa_mapping.py` - utilities for mapping edge dialect to TOSA
2121
- `tosa_quant_utils.py` - utilities for mapping quantization information to TOSA encoding
2222

23+
Operators:
24+
- `node_visitor.py` - Base class for edge operator lowering
25+
- `op_*.py` - Edge operator lowering/serialization to TOSA
26+
27+
Passes:
28+
- `arm_pass_manager.py` - Pass manager. Will decide which passes need to be applied depending on the compile_spec.
29+
- `*_pass.py` - Compiler passes derived from ExportPass
30+
2331
Quantization:
2432
- `arm_quantizer.py` - Quantizer for Arm backend
2533
- `arm_quantizer_utils.py` - Utilities for quantization
@@ -36,8 +44,10 @@ This is the structure of the test directory
3644

3745
```
3846
test # Root test folder
47+
├── misc # Testing of debug features
3948
├── models # Full model tests
4049
├── ops # Single op tests
50+
├── passes # Compiler passes tests
4151
├── tester # Arm Tester class
4252
├── tosautil # Utility functions for TOSA artifacts
4353
├ common.py # Common functions and definitions used by many tests

backends/arm/arm_backend.py

Lines changed: 11 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
import serializer.tosa_serializer as ts
1717
from executorch.backends.arm.arm_vela import vela_compile
1818
from executorch.backends.arm.operators.node_visitor import get_node_visitors
19+
from executorch.backends.arm.operators.op_output import process_output
1920
from executorch.backends.arm.operators.op_placeholder import process_placeholder
20-
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
21-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
21+
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
2222
from executorch.backends.arm.tosa_utils import (
2323
dbg_fail,
2424
dbg_tosa_dump,
25-
is_consumer_node_depthwise_conv2d,
26-
is_permute_node_before_addmm,
25+
process_call_function,
2726
)
2827
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2928
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -44,6 +43,7 @@ def __init__(self):
4443
self.compiler_flags = []
4544
self.output_format = None
4645
self.path_for_intermediates = None
46+
# TODO MLETORCH-265 Remove permute_nhwc flag
4747
self.permute_nhwc = False
4848
self.quantize_io = False
4949

@@ -216,18 +216,13 @@ def preprocess( # noqa: C901
216216
artifact_path = None
217217
output_format = ""
218218
compile_flags = []
219-
permute_memory_to_nhwc = False
220219
for spec in compile_spec:
221220
if spec.key == "debug_artifact_path":
222221
artifact_path = spec.value.decode()
223222
if spec.key == "output_format":
224223
output_format = spec.value.decode()
225224
if spec.key == "compile_flags":
226225
compile_flags.append(spec.value.decode())
227-
if spec.key == "permute_memory_format":
228-
memory_format = spec.value.decode()
229-
if memory_format == "nhwc":
230-
permute_memory_to_nhwc = True
231226

232227
# Check that the output format is set in the compile spec
233228
if not output_format:
@@ -241,81 +236,19 @@ def preprocess( # noqa: C901
241236
# Converted output for this subgraph, serializer needs path early as it emits
242237
# const data directly. Path created and data written only in debug builds.
243238
tosa_graph = ts.TosaSerializer(artifact_path)
239+
graph_module = ArmPassManager().transform_to_backend_pipeline(
240+
graph_module=edge_program.graph_module, compile_spec=compile_spec
241+
)
244242

245243
node_visitors = get_node_visitors(edge_program)
246244

247-
for node in edge_program.graph.nodes:
245+
for node in graph_module.graph.nodes:
248246
if node.op == "call_function":
249-
# Unpack arguments and convert
250-
inputs = []
251-
for arg in node.args:
252-
inputs.append(TosaArg(arg))
253-
254-
# Convert output (this node itself)
255-
output = TosaArg(node)
256-
257-
# TODO: fragile code for temporary fix, not all outputs will be
258-
# rank 4
259-
if permute_memory_to_nhwc and len(output.shape) == 4:
260-
# TODO: remove this if check
261-
# this is added because we need to align the quant node
262-
# output shape before the depthwise_conv2d node. The output
263-
# shape between TOSA conv2d and depthwise_conv2d are different.
264-
if (
265-
node.all_input_nodes[0].op
266-
== "placeholder" # check its parent is a placeholder
267-
and is_quant_node(node)
268-
and is_consumer_node_depthwise_conv2d(node)
269-
):
270-
NHWC_Order = [2, 3, 0, 1]
271-
else:
272-
NHWC_Order = [0, 2, 3, 1]
273-
output.shape = [output.shape[i] for i in NHWC_Order]
274-
275-
# Add output to TOSA graph
276-
tosa_graph.currRegion.currBasicBlock.addTensor(
277-
output.name,
278-
(
279-
inputs[0].shape
280-
if is_permute_node_before_addmm(node)
281-
else output.shape
282-
),
283-
(
284-
map_dtype(get_quant_node_dtype(node))
285-
if is_quant_node(node)
286-
else output.dtype
287-
),
288-
)
289-
290-
# Visiting each Node
291-
if node.target.__name__ in node_visitors:
292-
if node.target.__name__ in [
293-
"aten.add.Tensor",
294-
"aten._native_batch_norm_legit_no_training.default",
295-
]:
296-
node_visitors[node.target.__name__].define_node(
297-
node,
298-
tosa_graph,
299-
inputs,
300-
output,
301-
is_quant_node(node),
302-
permute_memory_to_nhwc,
303-
)
304-
else:
305-
node_visitors[node.target.__name__].define_node(
306-
node, tosa_graph, inputs, output, is_quant_node(node)
307-
)
308-
else:
309-
raise RuntimeError(f"Unknown operator {node.target}")
247+
process_call_function(node, tosa_graph, node_visitors)
310248
elif node.op == "placeholder":
311-
process_placeholder(
312-
node, tosa_graph, edge_program, permute_memory_to_nhwc
313-
)
249+
process_placeholder(node, tosa_graph, edge_program)
314250
elif node.op == "output":
315-
for output in node.args[0]:
316-
tosa_graph.addOutputTensor(
317-
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
318-
)
251+
process_output(node, tosa_graph)
319252
else:
320253
# This will only happen if an unpartitioned graph is passed without
321254
# any checking of compatibility.

backends/arm/arm_partitioner.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.arm.arm_backend import ArmBackend
13+
from executorch.backends.arm.passes.tag_io_quant_pass import TagIOQuantPass
1314
from executorch.exir.backend.compile_spec_schema import CompileSpec
1415
from executorch.exir.backend.partitioner import (
1516
DelegationSpec,
@@ -18,6 +19,7 @@
1819
)
1920
from executorch.exir.backend.utils import tag_constant_data
2021
from executorch.exir.dialects._ops import ops as exir_ops
22+
from executorch.exir.passes import PassManager
2123
from torch.export.exported_program import ExportedProgram
2224
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2325

@@ -54,9 +56,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5456
supported &= self.is_node_supported_custom(node)
5557

5658
# Override partitioning based on pre partition passes
57-
if supported and "arm_partition" in node.meta:
58-
supported = supported & node.meta["arm_partition"]
59-
node.meta.pop("arm_partition")
59+
if "arm_override_partition" in node.meta:
60+
supported = supported & node.meta["arm_override_partition"]
61+
node.meta.pop("arm_override_partition")
6062

6163
return supported
6264

@@ -69,54 +71,6 @@ def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
6971
return True
7072

7173

72-
from executorch.exir.pass_base import ExportPass, PassResult
73-
from executorch.exir.passes import PassManager
74-
75-
76-
class TagIOQuant(ExportPass):
77-
"""
78-
Pass run before partitioning to tag Q/DQ on any placeholder and output
79-
to ensure we don't greedily partition them for device. Float conversion
80-
has to happen outside a TOSA base inference profile.
81-
"""
82-
83-
def __init__(self, edge_program: torch.export.ExportedProgram):
84-
super(TagIOQuant, self).__init__()
85-
self.edge_program = edge_program
86-
87-
def is_quant_node(self, node: torch.fx.node.Node):
88-
return node.target in {
89-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
90-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
91-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
92-
}
93-
94-
def is_dequant_node(self, node: torch.fx.node.Node):
95-
return node.target in {
96-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
97-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
98-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
99-
}
100-
101-
def call(self, graph_module: torch.fx.GraphModule):
102-
for node in graph_module.graph.nodes:
103-
# tag q of input
104-
if node.op == "placeholder":
105-
for user in node.users.keys():
106-
# if we have an input going into a quantize
107-
if self.is_quant_node(user):
108-
user.meta["arm_partition"] = False
109-
110-
# tag dq of outputs
111-
if node.op == "output":
112-
quant, *_ = node.args[0]
113-
if self.is_dequant_node(quant):
114-
quant.meta["arm_partition"] = False
115-
116-
graph_module.recompile()
117-
return PassResult(graph_module, True)
118-
119-
12074
@final
12175
class ArmPartitioner(Partitioner):
12276
def __init__(self, compile_spec: List[CompileSpec]) -> None:
@@ -133,7 +87,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
13387
# Exclude IO quantization from the partition
13488
passes = PassManager(
13589
passes=[
136-
TagIOQuant(exported_program),
90+
TagIOQuantPass(),
13791
]
13892
)
13993
passes(exported_program.graph_module)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -9,7 +9,6 @@
99
op_addmm,
1010
op_avg_pool2d,
1111
op_batch_norm,
12-
op_clone,
1312
op_conv2d,
1413
op_dequant,
1514
op_div,

backends/arm/operators/op_add.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
build_rescale_from_int32,
1717
build_rescale_to_int32,
1818
)
19-
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs
19+
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
2020
from serializer.tosa_serializer import TosaOp
2121

2222

@@ -34,7 +34,6 @@ def define_node(
3434
inputs: List[TosaArg],
3535
output: TosaArg,
3636
is_quant_node: bool,
37-
permute_memory_to_nhwc: bool,
3837
) -> None:
3938
if is_quant_node:
4039
# Single input or not
@@ -54,12 +53,9 @@ def define_node(
5453
inputA_rescale_scale = input_A_scale.number / min_scale
5554
inputB_rescale_scale = input_B_scale.number / min_scale
5655

56+
input_A.shape = tosa_shape(input_A.shape, input_A.dim_order)
57+
input_B.shape = tosa_shape(input_B.shape, input_B.dim_order)
5758
broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
58-
if permute_memory_to_nhwc:
59-
NHWC_Order = [0, 2, 3, 1]
60-
broadcasted_shape = [broadcasted_shape[i] for i in NHWC_Order]
61-
input_A.shape = [input_A.shape[i] for i in NHWC_Order]
62-
input_B.shape = [input_B.shape[i] for i in NHWC_Order]
6359

6460
input_A_rescaled_to_int32 = build_rescale_to_int32(
6561
tosa_graph,

0 commit comments

Comments
 (0)