4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- from typing import List , Tuple
7
+ from typing import Callable , Dict , List , Tuple
8
8
9
9
import executorch .exir as exir
10
10
20
20
)
21
21
from executorch .backends .qualcomm .passes .convert_bmm_to_matmul import ConvertBmmToMatmul
22
22
from executorch .backends .qualcomm .passes .convert_hardsigmoid import ConvertHardsigmoid
23
- from executorch .backends .qualcomm .passes .convert_hardswish import ConvertHardswish
24
23
from executorch .backends .qualcomm .passes .convert_interpolate_with_upsample2d import (
25
24
ConvertInterpolateWithUpsample2D ,
26
25
)
29
28
from executorch .backends .qualcomm .passes .i64_to_i32 import I64toI32
30
29
from executorch .backends .qualcomm .passes .insert_requantize import InsertRequantize
31
30
from executorch .backends .qualcomm .passes .layout_transform import LayoutTransform
32
- from executorch .backends .qualcomm .passes .recompose_pixel_shuffle import (
33
- RecomposePixelShuffle ,
34
- )
35
31
from executorch .backends .qualcomm .passes .remove_clone import RemoveClone
36
32
from executorch .backends .qualcomm .serialization .qnn_compile_spec_schema import (
37
33
_soc_info_table ,
46
42
from executorch .backends .qualcomm .serialization .qnn_compile_spec_serialize import (
47
43
convert_to_flatbuffer ,
48
44
)
45
+ from executorch .exir import ExirExportedProgram
49
46
from executorch .exir .backend .compile_spec_schema import CompileSpec
47
+ from executorch .exir .program ._program import to_edge
48
+ from torch ._decomp import core_aten_decompositions
50
49
from torch .fx import passes
51
50
52
51
QNN_COMPILE_SPEC = "qnn_compile_spec"
@@ -60,32 +59,44 @@ def qnn_edge_config() -> exir.EdgeCompileConfig:
60
59
return exir .EdgeCompileConfig (_check_ir_validity = False )
61
60
62
61
62
+ def get_decomp_table () -> Dict [torch ._ops .OperatorBase , Callable ]:
63
+ source_decompositions = core_aten_decompositions ()
64
+ # The below super ops are supported by QNN
65
+ remove_decompositions = [
66
+ torch .ops .aten .pixel_shuffle .default ,
67
+ torch .ops .aten .hardswish .default ,
68
+ ]
69
+
70
+ return {
71
+ key : source_decompositions [key ]
72
+ for key in source_decompositions
73
+ if key not in remove_decompositions
74
+ }
75
+
76
+
63
77
def capture_program (
64
78
module : torch .nn .Module ,
65
79
inputs : Tuple [torch .Tensor ],
66
80
) -> exir .ExirExportedProgram :
67
- # TODO: should switch to torch.export.export & custom deomposition
68
- # to reduce maintaining effort.
69
- exir_exported_program = exir .capture (
70
- module ,
71
- inputs ,
72
- qnn_capture_config (),
73
- )
81
+ ep = torch .export .export (module , inputs )
82
+ decomposed_ep = ep .run_decompositions (get_decomp_table ())
83
+
74
84
# We choose call_operator by target in ConvertBinaryOpsWithScalar
75
85
# because it is the same source_fn_stack for MultiheadAttention
76
- exir_exported_program .transform (ConvertBinaryOpsWithScalar ())
77
- ex_prog = exir_exported_program .to_edge (qnn_edge_config ())
86
+ # TODO: Should modify the scalar op in the op builder instead of
87
+ # using transformation
88
+ core_ep = ExirExportedProgram (decomposed_ep , False )
89
+ core_ep .transform (ConvertBinaryOpsWithScalar ())
90
+ edge_ep_mgr = to_edge (core_ep .exported_program , compile_config = qnn_edge_config ())
78
91
79
92
# currently ExirExportedProgram.transform does not accept
80
93
# changes of input number which was caused by FoldQDQ
81
94
# apply passes one by one here to avoid IR capture failure
82
- edge_program = ex_prog .exported_program
95
+ edge_program = edge_ep_mgr .exported_program ()
83
96
graph_module = edge_program .graph_module
84
97
RemoveClone ()(graph_module )
85
- RecomposePixelShuffle ()(graph_module )
86
98
ConvertToLinear ()(graph_module )
87
99
ConvertHardsigmoid ()(graph_module )
88
- ConvertHardswish ()(graph_module )
89
100
ConvertBmmToMatmul ()(graph_module )
90
101
ConvertInterpolateWithUpsample2D ()(graph_module )
91
102
I64toI32 (edge_program )(graph_module )
@@ -95,7 +106,7 @@ def capture_program(
95
106
FoldQDQ ()(graph_module )
96
107
InsertRequantize (edge_program )(graph_module )
97
108
LayoutTransform (edge_program )(graph_module )
98
- return ex_prog
109
+ return edge_ep_mgr
99
110
100
111
101
112
def draw_graph (title , path , graph_module : torch .fx .GraphModule ):
0 commit comments