Skip to content

Commit f8eec08

Browse files
authored
Improve graph logging for debug purposes
Differential Revision: D68636227 Pull Request resolved: #7991
1 parent c361431 commit f8eec08

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

backends/cadence/aot/compiler.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def convert_pt2(
5656
model: torch.nn.Module,
5757
inputs: tuple[object, ...],
5858
quantizer: CadenceQuantizer,
59+
dump_graphs: bool = False,
5960
) -> torch.fx.GraphModule:
6061
"""
6162
Prepare and convert a model using the given quantizer.
@@ -86,6 +87,10 @@ def convert_pt2(
8687
.module()
8788
)
8889

90+
if dump_graphs:
91+
logging.info("Graph before quantization:")
92+
logging.info(model_gm.graph.print_tabular())
93+
8994
# Prepare
9095
prepared_model = prepare_pt2e(model_gm, quantizer)
9196

@@ -95,6 +100,10 @@ def convert_pt2(
95100
# Convert
96101
converted_model = convert_pt2e(prepared_model)
97102

103+
if dump_graphs:
104+
logging.info("Graph after quantization (before fusion):")
105+
logging.info(model_gm.graph.print_tabular())
106+
98107
return converted_model
99108

100109

@@ -127,6 +136,7 @@ def quantize_pt2(
127136
model: torch.nn.Module,
128137
inputs: tuple[object, ...],
129138
quantizer: Optional[CadenceQuantizer] = None,
139+
dump_graphs: bool = False,
130140
) -> torch.fx.GraphModule:
131141
"""
132142
Prepare, convert and fuse the model using the given quantizer.
@@ -140,19 +150,22 @@ def quantize_pt2(
140150
quantizer = CadenceDefaultQuantizer()
141151

142152
# Get converted graph module
143-
converted_gm = convert_pt2(model, inputs, quantizer)
153+
converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs)
144154

145155
# Get fused model
146156
fused_gm = fuse_pt2(converted_gm, quantizer)
147157

158+
if dump_graphs:
159+
logging.info("Graph after quantization and fusion:")
160+
logging.info(fused_gm.graph.print_tabular())
161+
148162
return fused_gm
149163

150164

151165
# Export the model and lower it to an ExportedProgram (in aten IR)
152166
def export_program(
153167
model: torch.nn.Module,
154168
inputs: tuple[object, ...],
155-
dump_graphs: bool = False,
156169
) -> ExportedProgram:
157170
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
158171

@@ -162,10 +175,6 @@ def export_program(
162175
# Export the model and return it.
163176
expo_program = export(model, inputs, strict=True)
164177

165-
if dump_graphs:
166-
logging.info("Exported graph:")
167-
expo_program.graph_module.graph.print_tabular()
168-
169178
return expo_program
170179

171180

@@ -179,7 +188,7 @@ def export_to_edge(
179188
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
180189

181190
# Export the model into an ExportedProgram.
182-
expo_program = export_program(model, inputs, dump_graphs=dump_graphs)
191+
expo_program = export_program(model, inputs)
183192

184193
# Call to_edge to convert the graph to edge IR.
185194
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
@@ -200,8 +209,10 @@ def export_to_edge(
200209
)
201210

202211
if dump_graphs:
203-
logging.info("Edge graph:")
204-
edge_prog_manager.exported_program().graph_module.graph.print_tabular()
212+
logging.info("Graph after Edge lowering:")
213+
logging.info(
214+
edge_prog_manager.exported_program().graph_module.graph.print_tabular()
215+
)
205216

206217
return edge_prog_manager
207218

0 commit comments

Comments
 (0)