
Commit a707550

mcremon-meta authored and facebook-github-bot committed
Remove pt2_quant flag (#3676)
Summary: Pull Request resolved: #3676

Reviewed By: dulinriley

Differential Revision: D57491621

fbshipit-source-id: 6a63e239839be950948085e392604c0ffc62e01a
1 parent 07dcf35 commit a707550

File tree: 2 files changed (+14 −22 lines)


backends/cadence/aot/compiler.py

Lines changed: 13 additions & 19 deletions
```diff
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Any, Callable, Tuple
+from typing import Any, Tuple
 
 import torch
 
@@ -16,25 +16,20 @@
 
 
 def export_program(
-    model: Callable,
+    model: torch.nn.Module,
     inputs: Any,
-    pt2_quant: bool = False,
 ) -> ExportedProgram:
-    # we don't support training mode. Make it eval
-    if hasattr(model, "eval"):
-        if pt2_quant:
-            # pyre-fixme[6]: Incompatible parameter type.
-            torch.ao.quantization.move_exported_model_to_eval(model)
-        else:
-            # pyre-fixme[16]: Anonymous callable has no attribute `eval`.
-            model.eval()
-
-    # if it's already an ExportedProgram, just return it
-    if isinstance(model, ExportedProgram):
-        return model
-
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
+    # If the model is already a GraphModule (most likely from quantization), call the
+    # suggested torch.ao.quantization API instead, which only does dropout and batchnorm.
+    if isinstance(model, torch.fx.GraphModule):
+        torch.ao.quantization.move_exported_model_to_eval(model)
+    else:
+        # We don't support training mode. Make it eval
+        if hasattr(model, "eval"):
+            model.eval()
+
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)
 
@@ -44,13 +39,12 @@ def export_program(
 
 # Export the model and lower it it edge IR.
 def export_to_edge(
-    model: Callable,
+    model: torch.nn.Module,
     inputs: Any,
-    pt2_quant: bool = False,
     dump_graphs: bool = False,
 ) -> Tuple[EdgeProgramManager, ExportedProgram]:
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs, pt2_quant)
+    expo_program = export_program(model, inputs)
 
     if dump_graphs:
         logging.info(f"Exported graph:\n{expo_program.graph_module.graph}")
```
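To see why the flag becomes redundant, here is a minimal, hedged sketch (not part of the commit): `Net` is a hypothetical module, and `capture_pre_autograd_graph` is assumed as the PT2E capture entry point of this era. Capture produces a `torch.fx.GraphModule`, whose plain `eval()` does not flip dropout/batchnorm behavior, which is why `torch.ao.quantization.move_exported_model_to_eval` is used for that case and the type check can pick the right path automatically.

```python
import torch
import torch.ao.quantization
from torch._export import capture_pre_autograd_graph  # assumed PT2E capture API of this era


class Net(torch.nn.Module):  # hypothetical example module, not from the diff
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear(x))


example_inputs = (torch.randn(2, 4),)
eager = Net()
captured = capture_pre_autograd_graph(Net(), example_inputs)  # a torch.fx.GraphModule

for model in (eager, captured):
    # Mirrors the dispatch export_program now performs: GraphModules go through
    # the quantization helper, which rewrites dropout/batchnorm nodes in the
    # graph; eager modules simply switch to eval mode.
    if isinstance(model, torch.fx.GraphModule):
        torch.ao.quantization.move_exported_model_to_eval(model)
    elif hasattr(model, "eval"):
        model.eval()
```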

backends/cadence/aot/export_example.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -69,9 +69,7 @@ def export_model(
     QuantFusion(patterns)(converted_model)
 
     # Get edge program (note: the name will change to export_to_cadence in future PRs)
-    edge_prog_manager, expo_prog = export_to_edge(
-        converted_model, example_inputs, pt2_quant=True
-    )
+    edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)
 
     # Run a couple required passes for quant/dequant ops
     cadence_prog_manager = edge_prog_manager.transform(
```
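Caller-side, the migration is mechanical: drop the keyword and let `export_program` infer the eval path from the model's type. A hedged before/after sketch, reusing the `converted_model` and `example_inputs` names from the surrounding export_example.py context:

```python
# Before this commit, quantized callers passed the flag explicitly:
# edge_prog_manager, expo_prog = export_to_edge(
#     converted_model, example_inputs, pt2_quant=True
# )

# After it, the same call works for a quantized GraphModule and a plain
# nn.Module alike; no flag needed.
edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)
```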
