Skip to content

Make quantize_pt2 return an ExportedProgram instead of a GraphModule #10644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def quantize_pt2(
quantizer: Optional[CadenceQuantizer] = None,
calibration_data: Optional[list[tuple[object, ...]]] = None,
dump_graphs: bool = False,
) -> torch.fx.GraphModule:
) -> ExportedProgram:
"""
Trace, prepare, convert and fuse the model using the given quantizer.
If calibration data is provided, it will be used to calibrate the model. If
Expand All @@ -178,7 +178,9 @@ def quantize_pt2(
logging.info("Graph after quantization and fusion:")
logging.info(fused_gm.graph.print_tabular())

return fused_gm
program = torch.export.export(fused_gm, inputs, strict=True)

return program


# Export the model and lower it to an ExportedProgram (in aten IR)
Expand Down Expand Up @@ -260,21 +262,43 @@ def quantize_and_export_to_edge(
dump_graphs: bool = False,
constant_methods: Optional[dict[str, object]] = None,
) -> EdgeProgramManager:
"""
Trace, quantize and lower a model/inputs pair to edge IR.
"""
quantized_model = quantize_pt2(
model,
inputs,
quantizer=quantizer,
dump_graphs=dump_graphs,
)

return export_to_edge(
return lower_ep_to_edge(
quantized_model,
inputs,
dump_graphs=dump_graphs,
constant_methods=constant_methods,
)


def lower_ep_to_cadence(
program: ExportedProgram,
dump_graphs: bool = False,
opt_level: int = 1,
) -> EdgeProgramManager:
"""
Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
"""
edge_prog_manager = lower_ep_to_edge(program, dump_graphs=dump_graphs)
cadence_passes = get_cadence_passes(opt_level)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
cast(
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
)
)
return cadence_prog_manager


def export_to_cadence(
model: torch.nn.Module,
inputs: tuple[object, ...],
Expand All @@ -299,11 +323,14 @@ def quantize_and_export_to_cadence(
dump_graphs: bool = False,
opt_level: int = 1,
) -> EdgeProgramManager:
"""
Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
optimization passes.
"""
quantized_model = quantize_pt2(model, inputs)

return export_to_cadence(
return lower_ep_to_cadence(
quantized_model,
inputs,
opt_level=opt_level,
dump_graphs=dump_graphs,
)
Expand Down
4 changes: 1 addition & 3 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from executorch.backends.cadence.aot.compiler import (
export_to_edge,
quantize_and_export_to_edge,
quantize_pt2,
)
from executorch.backends.cadence.aot.graph_builder import (
GraphBuilder,
Expand Down Expand Up @@ -113,9 +112,8 @@ def forward(self, x, y):
Y = torch.randn(y_shape)
p = ReplaceMatmulWithTransposedMatmulPass()
inputs = (X, Y)
quantized_model = quantize_pt2(model, inputs)
graph_module = (
export_to_edge(quantized_model, inputs).exported_program().graph_module
quantize_and_export_to_edge(model, inputs).exported_program().graph_module
)
# pyre-fixme[16]: Optional type has no attribute `graph_module`
graph_after_passes = p(graph_module).graph_module
Expand Down
Loading