@@ -85,27 +85,10 @@ def canonicalize_program(prog: ExportedProgram):
85
85
QNN_COMPILE_SPEC , convert_to_flatbuffer (options )
86
86
)
87
87
88
-
89
- def capture_program (
90
- module : torch .nn .Module ,
91
- inputs : Tuple [torch .Tensor ],
92
- ) -> exir .ExirExportedProgram :
93
- # TODO: should switch to torch.export.export & custom decomposition
94
- # to reduce maintenance effort.
95
- exir_exported_program = exir .capture (
96
- module ,
97
- inputs ,
98
- qnn_capture_config (),
99
- )
100
- # We choose call_operator by target in ConvertBinaryOpsWithScalar
101
- # because it is the same source_fn_stack for MultiheadAttention
102
- exir_exported_program .transform (ConvertBinaryOpsWithScalar ())
103
- ex_prog = exir_exported_program .to_edge (qnn_edge_config ())
104
-
88
+ def _transform (edge_program : ExportedProgram ) -> None :
105
89
# currently ExirExportedProgram.transform does not accept
106
90
# changes of input number which was caused by FoldQDQ
107
91
# apply passes one by one here to avoid IR capture failure
108
- edge_program = ex_prog .exported_program
109
92
graph_module = edge_program .graph_module
110
93
RemoveClone ()(graph_module )
111
94
RecomposePixelShuffle ()(graph_module )
@@ -121,6 +104,23 @@ def capture_program(
121
104
FoldQDQ ()(graph_module )
122
105
InsertRequantize (edge_program )(graph_module )
123
106
LayoutTransform (edge_program )(graph_module )
107
+
108
+ def capture_program (
109
+ module : torch .nn .Module ,
110
+ inputs : Tuple [torch .Tensor ],
111
+ ) -> exir .ExirExportedProgram :
112
+ # TODO: should switch to torch.export.export & custom decomposition
113
+ # to reduce maintenance effort.
114
+ exir_exported_program = exir .capture (
115
+ module ,
116
+ inputs ,
117
+ qnn_capture_config (),
118
+ )
119
+ # We choose call_operator by target in ConvertBinaryOpsWithScalar
120
+ # because it is the same source_fn_stack for MultiheadAttention
121
+ exir_exported_program .transform (ConvertBinaryOpsWithScalar ())
122
+ ex_prog = exir_exported_program .to_edge (qnn_edge_config ())
123
+ _transform (ex_prog .exported_program )
124
124
return ex_prog
125
125
126
126
0 commit comments