@@ -56,6 +56,7 @@ def convert_pt2(
56
56
model : torch .nn .Module ,
57
57
inputs : tuple [object , ...],
58
58
quantizer : CadenceQuantizer ,
59
+ dump_graphs : bool = False ,
59
60
) -> torch .fx .GraphModule :
60
61
"""
61
62
Prepare and convert a model using the given quantizer.
@@ -86,6 +87,10 @@ def convert_pt2(
86
87
.module ()
87
88
)
88
89
90
+ if dump_graphs :
91
+ logging .info ("Graph before quantization:" )
92
+ logging .info (model_gm .graph .print_tabular ())
93
+
89
94
# Prepare
90
95
prepared_model = prepare_pt2e (model_gm , quantizer )
91
96
@@ -95,6 +100,10 @@ def convert_pt2(
95
100
# Convert
96
101
converted_model = convert_pt2e (prepared_model )
97
102
103
+ if dump_graphs :
104
+ logging .info ("Graph after quantization (before fusion):" )
105
+ logging .info (model_gm .graph .print_tabular ())
106
+
98
107
return converted_model
99
108
100
109
@@ -127,6 +136,7 @@ def quantize_pt2(
127
136
model : torch .nn .Module ,
128
137
inputs : tuple [object , ...],
129
138
quantizer : Optional [CadenceQuantizer ] = None ,
139
+ dump_graphs : bool = False ,
130
140
) -> torch .fx .GraphModule :
131
141
"""
132
142
Prepare, convert and fuse the model using the given quantizer.
@@ -140,19 +150,22 @@ def quantize_pt2(
140
150
quantizer = CadenceDefaultQuantizer ()
141
151
142
152
# Get converted graph module
143
- converted_gm = convert_pt2 (model , inputs , quantizer )
153
+ converted_gm = convert_pt2 (model , inputs , quantizer , dump_graphs )
144
154
145
155
# Get fused model
146
156
fused_gm = fuse_pt2 (converted_gm , quantizer )
147
157
158
+ if dump_graphs :
159
+ logging .info ("Graph after quantization and fusion:" )
160
+ logging .info (fused_gm .graph .print_tabular ())
161
+
148
162
return fused_gm
149
163
150
164
151
165
# Export the model and lower it to an ExportedProgram (in aten IR)
152
166
def export_program (
153
167
model : torch .nn .Module ,
154
168
inputs : tuple [object , ...],
155
- dump_graphs : bool = False ,
156
169
) -> ExportedProgram :
157
170
assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
158
171
@@ -162,10 +175,6 @@ def export_program(
162
175
# Export the model and return it.
163
176
expo_program = export (model , inputs , strict = True )
164
177
165
- if dump_graphs :
166
- logging .info ("Exported graph:" )
167
- expo_program .graph_module .graph .print_tabular ()
168
-
169
178
return expo_program
170
179
171
180
@@ -179,7 +188,7 @@ def export_to_edge(
179
188
assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
180
189
181
190
# Export the model into an ExportedProgram.
182
- expo_program = export_program (model , inputs , dump_graphs = dump_graphs )
191
+ expo_program = export_program (model , inputs )
183
192
184
193
# Call to_edge to convert the graph to edge IR.
185
194
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
@@ -200,8 +209,10 @@ def export_to_edge(
200
209
)
201
210
202
211
if dump_graphs :
203
- logging .info ("Edge graph:" )
204
- edge_prog_manager .exported_program ().graph_module .graph .print_tabular ()
212
+ logging .info ("Graph after Edge lowering:" )
213
+ logging .info (
214
+ edge_prog_manager .exported_program ().graph_module .graph .print_tabular ()
215
+ )
205
216
206
217
return edge_prog_manager
207
218
0 commit comments