Commit 6be49e1

Update on "[4/N] Add backend options map"
This is to manage the backend <-> BackendOptions map. Users will create the backend options map, and the ET runtime will read the backend name and dispatch the list of backend options to each backend.

exported-using-ghexport

Differential Revision: [D76149466](https://our.internmc.facebook.com/intern/diff/D76149466/)

[ghstack-poisoned]
2 parents ff6bfd9 + 49a3aa5 commit 6be49e1

40 files changed: +776 additions, -239 deletions
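Per the commit message, the stack builds a map from backend name to that backend's options, which the runtime walks to dispatch options. Below is a minimal Python sketch of that dispatch idea; `BackendOption`, `backend_options_map`, and `dispatch_options` are hypothetical stand-ins, not the actual ExecuTorch runtime API (which is C++).

# Hypothetical sketch of the backend <-> BackendOptions dispatch described above;
# every name here is illustrative, not the real ExecuTorch API.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class BackendOption:
    key: str
    value: Any


# Users build a map from backend name to the list of options for that backend...
backend_options_map: Dict[str, List[BackendOption]] = {
    "QnnBackend": [BackendOption("log_level", "debug")],
    "XnnpackBackend": [BackendOption("num_threads", 4)],
}


def dispatch_options(backend_name: str) -> List[BackendOption]:
    # ...and the runtime reads a backend's name and hands it its option list.
    return backend_options_map.get(backend_name, [])


assert dispatch_options("QnnBackend")[0].key == "log_level"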

README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ It supports a wide range of models including LLMs (Large Language Models), CV (C
 Platform Support:
 - Operating Systems:
   - iOS
-  - Mac
+  - MacOS (ARM64)
   - Android
   - Linux
   - Microcontrollers

backends/arm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if(NOT EXECUTORCH_ROOT)
   set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
 endif()
 
+add_compile_options("-Wall" "-Werror")
+
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 
 set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,6 +8,7 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
+from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -35,7 +36,6 @@
 from .remove_0d_tensor import Remove0DTensor
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
-from .replace_index_put_input import ReplaceIndexPutInput
 from .replace_inf_values import ReplaceInfValues
 from .tag_quant_io import TagQuantIO
 
@@ -45,6 +45,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -72,7 +73,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 ]
backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import operator
+from collections import Counter
+from typing import List
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class ConvertBmmToMatmul(ExportPass):
+    """
+    Replace bmm with matmul, because bmm is equal to matmul in QNN.
+    Handle missing quantization tag for bmm op.
+    """
+
+    view_copy = exir_ops.edge.aten.view_copy.default
+    expand_copy = exir_ops.edge.aten.expand_copy.default
+    clone = exir_ops.edge.aten.clone.default
+    bmm = exir_ops.edge.aten.bmm.default
+    matmul = exir_ops.edge.aten.matmul.default
+    patterns = [
+        {expand_copy: 2, view_copy: 3, bmm: 1},
+        {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
+        {bmm: 1},
+    ]
+
+    def __init__(self):
+        super(ConvertBmmToMatmul, self).__init__()
+
+    def _get_ordered_inputs(
+        self, inputs: List[torch.fx.Node], output: torch.fx.Node
+    ) -> List[torch.fx.Node]:
+        bmm_inputs = []
+        for arg in output.args:
+            while arg not in inputs:
+                arg = arg.args[0]
+            bmm_inputs.append(arg)
+        return bmm_inputs
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        partitions = get_source_partitions(
+            graph,
+            [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
+        )
+        for _, src_partitions in partitions.items():
+            for src_partition in src_partitions:
+                op_cnt = Counter([n.target for n in src_partition.nodes])
+                if op_cnt not in self.patterns:
+                    raise AssertionError(
+                        "Found a new pattern that needs to be converted to matmul"
+                    )
+
+                inputs = src_partition.input_nodes
+                bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
+                output = src_partition.output_nodes[0]
+                # The order of src_partition.input_nodes is not guaranteed.
+                lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
+                with graph_module.graph.inserting_before(output):
+                    # Replace bmm with matmul, because bmm is equal to matmul in QNN.
+                    matmul_node = graph.create_node(
+                        "call_function", self.matmul, (lhs, rhs)
+                    )
+                    matmul_node.meta = output.meta
+                    for user in output.users.copy():
+                        user.replace_input_with(output, matmul_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
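A quick way to see the pass in action is the bare `{bmm: 1}` pattern. The sketch below is a hedged usage example: it assumes a generic `torch.export` + `to_edge` pipeline rather than the full Qualcomm `capture_program` flow, and `BmmModule` is a made-up toy model.

# Minimal usage sketch, assuming a plain to_edge() pipeline; this is not the
# full Qualcomm capture flow.
import torch
from executorch.backends.qualcomm._passes import ConvertBmmToMatmul
from executorch.exir import to_edge


class BmmModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)


example = (torch.randn(2, 3, 4), torch.randn(2, 4, 5))
edge = to_edge(torch.export.export(BmmModule(), example))
graph_module = edge.exported_program().graph_module

# The pass rewrites the matched aten.bmm node into aten.matmul in place.
ConvertBmmToMatmul()(graph_module)
print(graph_module.graph)  # should now contain aten.matmul instead of aten.bmm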

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 8 additions & 2 deletions
@@ -9,7 +9,10 @@
 
 from executorch.backends.qualcomm.builders.node_visitor import q_ops
 
-from executorch.backends.qualcomm.builders.utils import is_parameter
+from executorch.backends.qualcomm.builders.utils import (
+    is_mutable_buffer_input,
+    is_parameter,
+)
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_ENCODING,
     QCOM_QUANT_ATTRS,
@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if (
                 n.op == "placeholder"
                 and n.meta.get(QCOM_QUANT_ATTRS)
-                and not is_parameter(n, self.edge_program)
+                and (
+                    not is_parameter(n, self.edge_program)
+                    or is_mutable_buffer_input(n, self.edge_program)
+                )
             ):
                 self._insert_quant_node(
                     graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
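The widened condition now also admits placeholders that are mutable-buffer inputs, which `is_parameter` alone would exclude. As a hedged sketch of what `is_mutable_buffer_input` plausibly checks against the export graph signature (the real helper lives in backends/qualcomm/builders/utils.py and may differ in detail):

# Plausible reconstruction of the helper's logic from the graph signature;
# the actual utils.py implementation may differ.
import torch
from torch.export import ExportedProgram


def is_mutable_buffer_input(node: torch.fx.Node, edge_program: ExportedProgram) -> bool:
    if node.op != "placeholder":
        return False
    sig = edge_program.graph_signature
    # A mutable-buffer input is a buffer placeholder whose underlying buffer
    # is also the target of some mutation recorded in buffers_to_mutate.
    fqn = sig.inputs_to_buffers.get(node.target)
    return fqn is not None and fqn in sig.buffers_to_mutate.values()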

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 10 additions & 3 deletions
@@ -13,6 +13,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -40,7 +41,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 )
@@ -80,6 +80,7 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
+        (ConvertBmmToMatmul, False),
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
@@ -92,7 +93,6 @@ def get_capture_program_passes():
         (RecomposeRmsNorm, False),
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
-        (ReplaceIndexPutInput, True),
         (TagQuantIO, False),
     ]
 
@@ -224,4 +224,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(LayoutTransform(exported_program, insert_permute=True))
         self.add_pass(FuseConsecutiveCast())
         self.add_pass(FuseConsecutiveTranspose())
-        return self._transform(exported_program.graph_module)
+        self._transform(exported_program.graph_module)
+        # Update inputs_to_buffers and buffers_to_mutate in the graph signature for mutable buffers.
+        # Q/DQ nodes inserted at the I/O rename nodes, so output node names would otherwise fail to map to their buffers.
+        exported_program._graph_signature = _get_updated_graph_signature(
+            exported_program.graph_signature,
+            exported_program.graph_module,
+        )
+        return exported_program.graph_module
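The signature refresh matters because `buffers_to_mutate` is keyed by output node name; once I/O Q/DQ insertion rewires the outputs, the stale keys no longer name real nodes. A toy illustration follows (all names are hypothetical):

# Toy illustration of the stale-signature problem; names are made up.
buffers_to_mutate = {"index_put_default": "layers_0_kv_cache"}  # output node name -> buffer FQN

# After Q/DQ insertion, the graph output feeding the buffer is a new quantize
# node, so the recorded node name no longer appears in the graph.
current_output_name = "quantize_per_tensor_default"
assert current_output_name not in buffers_to_mutate  # lookup by node name fails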

backends/qualcomm/_passes/replace_index_put_input.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

backends/qualcomm/_passes/utils.py

Lines changed: 4 additions & 3 deletions
@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateQuantAttrs,
         AnnotateStack,
         AnnotateUnbind,
+        ConvertBmmToMatmul,
         ConvertConv1dToConv2d,
         DecomposeAny,
         DecomposeColIm,
@@ -76,18 +77,19 @@ def get_passes_dependency_for_capture_program():
         RecomposePixelUnshuffle,
         RecomposeRmsNorm,
         RemoveRedundancy,
-        ReplaceIndexPutInput,
         TagQuantIO,
     )
 
     return {
         AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
         AnnotateQuantAttrs: [
+            ConvertBmmToMatmul,
             RecomposePixelUnshuffle,
             RemoveRedundancy,
         ],
         AnnotateStack: [RemoveRedundancy],
         AnnotateUnbind: [RemoveRedundancy],
+        ConvertBmmToMatmul: [RecomposePixelUnshuffle],
         DecomposeAny: [RemoveRedundancy],
         DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
@@ -103,8 +105,7 @@ def get_passes_dependency_for_capture_program():
         ],
         RecomposePixelUnshuffle: [RemoveRedundancy],
         RecomposeRmsNorm: [RemoveRedundancy],
-        ReplaceIndexPutInput: [LayoutTransform],
-        TagQuantIO: [ReplaceIndexPutInput],
+        TagQuantIO: [LayoutTransform],
     }
 
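Each entry appears to read "the key pass depends on the passes in its value list", so a scheduler can order the passes with a plain topological sort. A generic sketch over a few of the entries above (this is an illustration, not the actual QnnPassManager code):

# Generic topological-sort sketch over a {pass: [prerequisites]} dict;
# illustration only, not the QnnPassManager implementation.
from graphlib import TopologicalSorter

deps = {
    "AnnotateQuantAttrs": ["ConvertBmmToMatmul", "RecomposePixelUnshuffle", "RemoveRedundancy"],
    "ConvertBmmToMatmul": ["RecomposePixelUnshuffle"],
    "RecomposePixelUnshuffle": ["RemoveRedundancy"],
    "RemoveRedundancy": [],
}

# Prerequisites come out before their dependents, e.g.:
# RemoveRedundancy, RecomposePixelUnshuffle, ConvertBmmToMatmul, AnnotateQuantAttrs
print(list(TopologicalSorter(deps).static_order()))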

backends/qualcomm/builders/node_visitor.py

Lines changed: 33 additions & 11 deletions
@@ -41,6 +41,8 @@
     get_parameter,
     is_graph_input,
     is_graph_output,
+    is_mutable_buffer_input,
+    is_mutable_buffer_output,
     is_parameter,
 )
 
@@ -307,7 +309,9 @@ def get_tensor_type(
         node: torch.fx.Node,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        is_input = is_graph_input(node, self.edge_program)
+        is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
+            node, self.edge_program
+        )
         is_output = is_graph_output(node)
         # handle logic for input/output tensors
         if is_input or is_output:
@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims):
 
         return dynamic_dims if any(dynamic_dims) else [], nominal_dims
 
+    def get_tensor_name(
+        self,
+        node: torch.fx.Node,
+        wrapper_idx: int = 0,
+    ):
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        # The `input_{id}` prefix is used for sorting at runtime. Due to multiple passes in qnn_preprocess,
+        # the input order between QNN and the original graph's forward function may differ.
+        # The `mutbuf_{id}` segment is used to map mutable-buffer I/O at runtime.
+        # The `output_` prefix identifies the graph's outputs at runtime, to avoid confusion with per_tensor_dump.
+        if is_mutable_buffer_input(node, self.edge_program):
+            fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.values()
+            ).index(fqn)
+            tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
+        elif is_graph_input(node, self.edge_program):
+            tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
+        elif is_mutable_buffer_output(node, self.edge_program):
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.keys()
+            ).index(node.name)
+            tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
+        elif is_graph_output(node):
+            tensor_name = f"output_{tensor_name}"
+        return tensor_name
+
     def define_custom_tensor_wrapper(
         self,
         node_name: str,
@@ -413,16 +444,7 @@ def define_tensor(
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached
 
-        tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
-        if is_graph_input(tensor_source_node, self.edge_program):
-            tensor_name = (
-                "input_"
-                + str(self.external_ids[tensor_source_node])
-                + "_"
-                + tensor_name
-            )
-        if is_graph_output(tensor_source_node):
-            tensor_name = "output_" + tensor_name
+        tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
         dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
         dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
         tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
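Concretely, the scheme encodes everything the runtime needs in the tensor name itself. A few hypothetical examples of the names it produces (ids, positions, and node names below are made up):

# Hypothetical outputs of the naming scheme; all values below are made up.
external_id, mutbuf_pos, wrapper_idx = 1, 0, 0

print(f"input_{external_id}_mutbuf_{mutbuf_pos}_kv_cache_{wrapper_idx}")
# -> input_1_mutbuf_0_kv_cache_0: mutable-buffer input, arg 1, first mutated buffer

print(f"input_0_x_{wrapper_idx}")
# -> input_0_x_0: ordinary graph input, arg 0

print(f"output_mutbuf_{mutbuf_pos}_index_put_{wrapper_idx}")
# -> output_mutbuf_0_index_put_0: paired with the mutbuf_0 input at runtime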

backends/qualcomm/builders/node_visitor_manager.py

Lines changed: 4 additions & 2 deletions
@@ -13,7 +13,7 @@
 
 from .node_visitor import NodeVisitor
 from .op_custom_op import CustomOp
-from .utils import is_graph_input, is_graph_output
+from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input
 
 
 # This will hold mapping of all node names to the visitor class
@@ -39,7 +39,9 @@ def generate_node_to_external_map(
         # The order in which we visit the placeholder node is same as the *args
         # order for the forward(*args) signature for this gm. Using the order of
         # the nodes as external_id to extract the right arg from *args at runtime
-        if is_graph_input(node, edge_program):
+        if is_graph_input(node, edge_program) or is_mutable_buffer_input(
+            node, edge_program
+        ):
             node_to_external_map[node] = len(node_to_external_map)
     for node in edge_program.graph_module.graph.nodes:
         if is_graph_output(node):
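With this change a mutable-buffer placeholder consumes an external id exactly like a user input, keeping runtime arg extraction in *args order. A toy walk-through (the placeholder list and flags are hypothetical):

# Toy walk-through of external-id assignment; placeholders and flags are made up.
placeholders = [
    ("conv_weight", False),  # parameter: gets no external id
    ("x", True),             # user input           -> external id 0
    ("kv_cache", True),      # mutable buffer input -> external id 1
]

node_to_external_map = {}
for name, is_external_input in placeholders:
    if is_external_input:
        node_to_external_map[name] = len(node_to_external_map)

print(node_to_external_map)  # {'x': 0, 'kv_cache': 1}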
