Skip to content

Commit bbf5867

Browse files
author
quic_chuntl
committed
Qualcomm AI Engine Direct - Adapt to new IR capture flow
Summary: - Change the existing IR capture flow (exir.capture) to torch.export.export - Add a custom decomposition table to reduce maintenance effort - Fix the breakages encountered and make sure all tests pass
1 parent 57e192b commit bbf5867

File tree

3 files changed

+39
-31
lines changed

3 files changed

+39
-31
lines changed

backends/qualcomm/tests/utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from executorch.backends.qualcomm.utils.utils import capture_program
2626
from executorch.examples.qualcomm.scripts.utils import SimpleADB
2727

28-
from executorch.exir.backend.backend_api import to_backend
2928
from executorch.exir.backend.compile_spec_schema import CompileSpec
3029
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
3130

@@ -107,19 +106,18 @@ def lower_module_and_test_output(
107106
qnn_partitioner = QnnPartitioner(
108107
self.compiler_specs, skip_node_id_set, skip_node_op_set
109108
)
110-
delegated_program = capture_program(module, sample_inputs)
111-
delegated_program.exported_program = to_backend(
112-
delegated_program.exported_program, qnn_partitioner
113-
)
114-
exec_prog = delegated_program.to_executorch()
109+
delegated_program_mgr = capture_program(module, sample_inputs)
110+
delegated_program_mgr = delegated_program_mgr.to_backend(qnn_partitioner)
111+
exec_prog_mgr = delegated_program_mgr.to_executorch()
115112

116113
# Assert the backend name is qnn
117114
self.assertEqual(
118-
len(exec_prog.program.execution_plan[0].delegates), expected_partitions
115+
len(exec_prog_mgr.executorch_program.execution_plan[0].delegates),
116+
expected_partitions,
119117
)
120118
for i in range(expected_partitions):
121119
self.assertEqual(
122-
exec_prog.program.execution_plan[0].delegates[i].id,
120+
exec_prog_mgr.executorch_program.execution_plan[0].delegates[i].id,
123121
QnnBackend.__name__,
124122
)
125123

@@ -132,7 +130,7 @@ def lower_module_and_test_output(
132130
pte_fname,
133131
) = self._save_model_and_expected_output(
134132
module,
135-
exec_prog.buffer,
133+
exec_prog_mgr.buffer,
136134
sample_inputs,
137135
tmp_dir,
138136
)

backends/qualcomm/utils/utils.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import List, Tuple
7+
from typing import Callable, Dict, List, Tuple
88

99
import executorch.exir as exir
1010

@@ -20,7 +20,6 @@
2020
)
2121
from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
2222
from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
23-
from executorch.backends.qualcomm.passes.convert_hardswish import ConvertHardswish
2423
from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
2524
ConvertInterpolateWithUpsample2D,
2625
)
@@ -29,9 +28,6 @@
2928
from executorch.backends.qualcomm.passes.i64_to_i32 import I64toI32
3029
from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize
3130
from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
32-
from executorch.backends.qualcomm.passes.recompose_pixel_shuffle import (
33-
RecomposePixelShuffle,
34-
)
3531
from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
3632
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
3733
_soc_info_table,
@@ -46,7 +42,10 @@
4642
from executorch.backends.qualcomm.serialization.qnn_compile_spec_serialize import (
4743
convert_to_flatbuffer,
4844
)
45+
from executorch.exir import ExirExportedProgram
4946
from executorch.exir.backend.compile_spec_schema import CompileSpec
47+
from executorch.exir.program._program import to_edge
48+
from torch._decomp import core_aten_decompositions
5049
from torch.fx import passes
5150

5251
QNN_COMPILE_SPEC = "qnn_compile_spec"
@@ -60,32 +59,44 @@ def qnn_edge_config() -> exir.EdgeCompileConfig:
6059
return exir.EdgeCompileConfig(_check_ir_validity=False)
6160

6261

62+
def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
63+
source_decompositions = core_aten_decompositions()
64+
# The below super ops are supported by QNN
65+
remove_decompositions = [
66+
torch.ops.aten.pixel_shuffle.default,
67+
torch.ops.aten.hardswish.default,
68+
]
69+
70+
return {
71+
key: source_decompositions[key]
72+
for key in source_decompositions
73+
if key not in remove_decompositions
74+
}
75+
76+
6377
def capture_program(
6478
module: torch.nn.Module,
6579
inputs: Tuple[torch.Tensor],
6680
) -> exir.ExirExportedProgram:
67-
# TODO: should switch to torch.export.export & custom deomposition
68-
# to reduce maintaining effort.
69-
exir_exported_program = exir.capture(
70-
module,
71-
inputs,
72-
qnn_capture_config(),
73-
)
81+
ep = torch.export.export(module, inputs)
82+
decomposed_ep = ep.run_decompositions(get_decomp_table())
83+
7484
# We choose call_operator by target in ConvertBinaryOpsWithScalar
7585
# because it is the same source_fn_stack for MultiheadAttention
76-
exir_exported_program.transform(ConvertBinaryOpsWithScalar())
77-
ex_prog = exir_exported_program.to_edge(qnn_edge_config())
86+
# TODO: Should modify the scalar op in the op builder instead of
87+
# using transformation
88+
core_ep = ExirExportedProgram(decomposed_ep, False)
89+
core_ep.transform(ConvertBinaryOpsWithScalar())
90+
edge_ep_mgr = to_edge(core_ep.exported_program, compile_config=qnn_edge_config())
7891

7992
# currently ExirExportedProgram.transform does not accept
8093
# changes of input number which was caused by FoldQDQ
8194
# apply passes one by one here to avoid IR capture failure
82-
edge_program = ex_prog.exported_program
95+
edge_program = edge_ep_mgr.exported_program()
8396
graph_module = edge_program.graph_module
8497
RemoveClone()(graph_module)
85-
RecomposePixelShuffle()(graph_module)
8698
ConvertToLinear()(graph_module)
8799
ConvertHardsigmoid()(graph_module)
88-
ConvertHardswish()(graph_module)
89100
ConvertBmmToMatmul()(graph_module)
90101
ConvertInterpolateWithUpsample2D()(graph_module)
91102
I64toI32(edge_program)(graph_module)
@@ -95,7 +106,7 @@ def capture_program(
95106
FoldQDQ()(graph_module)
96107
InsertRequantize(edge_program)(graph_module)
97108
LayoutTransform(edge_program)(graph_module)
98-
return ex_prog
109+
return edge_ep_mgr
99110

100111

101112
def draw_graph(title, path, graph_module: torch.fx.GraphModule):

examples/qualcomm/scripts/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
capture_program,
2727
generate_qnn_executorch_compiler_spec,
2828
)
29-
from executorch.exir.backend.backend_api import to_backend
3029
from executorch.exir.capture._config import ExecutorchBackendConfig
3130
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
3231

@@ -185,9 +184,9 @@ def build_executorch_binary(
185184
skip_node_id_set,
186185
skip_node_op_set,
187186
)
188-
edge_prog.exported_program = to_backend(edge_prog.exported_program, qnn_partitioner)
189-
edge_prog.exported_program.graph_module.graph.print_tabular()
190-
exec_prog = edge_prog.to_executorch(
187+
delegated_program_mgr = edge_prog.to_backend(qnn_partitioner)
188+
edge_prog.exported_program().graph_module.graph.print_tabular()
189+
exec_prog = delegated_program_mgr.to_executorch(
191190
config=ExecutorchBackendConfig(extract_constant_segment=False)
192191
)
193192
with open(f"{file_name}.pte", "wb") as file:

0 commit comments

Comments
 (0)