Skip to content

Commit 553b669

Browse files
mcremon-meta authored and facebook-github-bot committed
Remove pt2_quant flag
Summary: It's been on the to-do list for a while to clean that up. It's only used in `export_program` to properly put the model in eval mode. Now that we only allow `nn.Module`, there are only two cases: `nn.Module`, which will have `eval()`, and `GraphModule`, which can use `torch.ao.quantization.move_exported_model_to_eval`, which we already called before with the `pt2_quant` flag. Now that the flag is not needed, remove it everywhere! We also promote the `quantize_and_export_program` function to `__init__.py` as a compiler API, because it can be quite useful. Differential Revision: D57491621
1 parent 0f21c66 commit 553b669

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

backends/cadence/aot/compiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717

1818
def export_program(
19-
model: Callable,
19+
model: torch.nn.Module,
2020
inputs: Any,
2121
pt2_quant: bool = False,
2222
) -> ExportedProgram:
23-
# we don't support training mode. Make it eval
23+
# We don't support training mode. Make it eval
2424
if hasattr(model, "eval"):
25-
if pt2_quant:
26-
# pyre-fixme[6]: Incompatible parameter type.
25+
# If the model is already a GraphModule (most likely from quantization),
26+
# it can't call eval. Call the suggested torch.ao.quantization API instead,
27+
# which only does dropout and batchnorm.
28+
if isinstance(model, torch.fx.GraphModule):
2729
torch.ao.quantization.move_exported_model_to_eval(model)
2830
else:
29-
# pyre-fixme[16]: Anonymous callable has no attribute `eval`.
3031
model.eval()
3132

3233
# if it's already an ExportedProgram, just return it
@@ -46,11 +47,10 @@ def export_program(
4647
def export_to_edge(
4748
model: Callable,
4849
inputs: Any,
49-
pt2_quant: bool = False,
5050
dump_graphs: bool = False,
5151
) -> Tuple[EdgeProgramManager, ExportedProgram]:
5252
# Export the model into an ExportedProgram.
53-
expo_program = export_program(model, inputs, pt2_quant)
53+
expo_program = export_program(model, inputs)
5454

5555
if dump_graphs:
5656
logging.info(f"Exported graph:\n{expo_program.graph_module.graph}")

backends/cadence/aot/export_example.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def export_model(
6969
QuantFusion(patterns)(converted_model)
7070

7171
# Get edge program (note: the name will change to export_to_cadence in future PRs)
72-
edge_prog_manager, expo_prog = export_to_edge(
73-
converted_model, example_inputs, pt2_quant=True
74-
)
72+
edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)
7573

7674
# Run a couple required passes for quant/dequant ops
7775
cadence_prog_manager = edge_prog_manager.transform(

0 commit comments

Comments (0)