Skip to content

Commit 8fc3e20

Browse files
mcremon-meta authored and facebook-github-bot committed
Cleanup export_model API calls (#5882)
Summary:
Pull Request resolved: #5882

Lots of things are redundant and a few need to move to utils. Subsequent changes will split the export function and separate the run part.

Main changes:
- call `fuse_pt2` after `convert_pt2` instead of `quantize_pt2`, and avoid calling `convert_pt2` twice
- move `print_ops_info` into `export_to_cadence`
- remove the need to call `export_to_edge` in `export_model`
- move the serialization utils to `utils.py`

Reviewed By: zonglinpeng

Differential Revision: D63795843

fbshipit-source-id: 7eb482b0daccf64d3f1ca73ffb5e5148584a6678
1 parent 84f5a56 commit 8fc3e20

File tree

4 files changed

+64
-58
lines changed

4 files changed

+64
-58
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python_library(
2222
deps = [
2323
"fbsource//third-party/pypi/tabulate:tabulate",
2424
"//caffe2:torch",
25+
"//executorch/exir:lib",
2526
"//executorch/exir:memory",
2627
"//executorch/exir/dialects:lib",
2728
"//executorch/exir/dialects/edge:lib",

backends/cadence/aot/compiler.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from torch.export import export
3737
from torch.export.exported_program import ExportedProgram
3838

39+
from .utils import print_ops_info
40+
3941

4042
# Note: this is not meant as a primary API since it can create inconsistencies
4143
# if the quantizer here is different from the quantizer used to convert. It is
@@ -193,16 +195,17 @@ def export_to_edge(
193195

194196

195197
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
196-
# apply passes specific to Cadence DSP execution.
198+
# apply passes specific to Cadence DSP execution. Return both to print the
199+
# differences.
197200
def export_to_cadence(
198201
model: torch.nn.Module,
199202
inputs: tuple[object, ...],
200203
dump_graphs: bool = False,
201204
) -> EdgeProgramManager:
202-
edge_program_manager = export_to_edge(model, inputs)
205+
edge_prog_manager = export_to_edge(model, inputs)
203206

204207
# Run a couple required passes for quant/dequant ops
205-
cadence_program_manager = edge_program_manager.transform(
208+
cadence_prog_manager = edge_prog_manager.transform(
206209
[
207210
InitializePipeline(),
208211
RemoveZeroSizedCatArgsPass(),
@@ -216,4 +219,10 @@ def export_to_cadence(
216219
]
217220
)
218221

219-
return cadence_program_manager
222+
# Print some information to terminal
223+
print_ops_info(
224+
edge_prog_manager.exported_program().graph_module,
225+
cadence_prog_manager.exported_program().graph_module,
226+
)
227+
228+
return cadence_prog_manager

backends/cadence/aot/export_example.py

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,26 @@
1010
import tempfile
1111

1212
from executorch.backends.cadence.aot.ops_registrations import * # noqa
13-
import os
1413
from typing import Any, Tuple
1514

1615
from executorch.backends.cadence.aot.compiler import (
1716
convert_pt2,
1817
export_to_cadence,
19-
export_to_edge,
20-
quantize_pt2,
18+
fuse_pt2,
2119
)
2220
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
2321
from executorch.backends.cadence.runtime import runtime
2422
from executorch.backends.cadence.runtime.executor import BundledProgramManager
2523
from executorch.exir import ExecutorchProgramManager
2624
from torch import nn
2725

28-
from .utils import print_ops_info
26+
from .utils import save_bpte_program, save_pte_program
2927

3028

3129
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
3230
logging.basicConfig(level=logging.INFO, format=FORMAT)
3331

3432

35-
def _save_pte_program(
36-
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
37-
) -> None:
38-
if model_name.endswith(".pte"):
39-
filename = model_name
40-
else:
41-
filename = os.path.join(output_dir, f"{model_name}.pte")
42-
43-
try:
44-
with open(filename, "wb") as file:
45-
prog.write_to_file(file)
46-
logging.info(f"Saved exported program to {filename}")
47-
except Exception as e:
48-
logging.error(f"Error while saving to {filename}: {e}")
49-
50-
51-
def _save_bpte_program(
52-
buffer: bytes,
53-
model_name: str,
54-
output_dir: str = "",
55-
) -> None:
56-
if model_name.endswith(".bpte"):
57-
filename = model_name
58-
else:
59-
filename = os.path.join(output_dir, f"{model_name}.bpte")
60-
try:
61-
with open(filename, "wb") as f:
62-
f.write(buffer)
63-
logging.info(f"Saved exported program to {filename}")
64-
except Exception as e:
65-
logging.error(f"Error while saving to {output_dir}: {e}")
66-
67-
6833
def export_model(
6934
model: nn.Module,
7035
example_inputs: Tuple[Any, ...],
@@ -74,32 +39,28 @@ def export_model(
7439
working_dir = tempfile.mkdtemp(dir="/tmp")
7540
logging.debug(f"Created work directory {working_dir}")
7641

77-
# convert the model (also called in quantize_pt2)
78-
converted_model = convert_pt2(model, example_inputs, CadenceQuantizer())
42+
# Instantiate the quantizer
43+
quantizer = CadenceQuantizer()
7944

80-
# Get reference outputs from quantized_model
81-
ref_outputs = converted_model(*example_inputs)
45+
# Convert the model
46+
converted_model = convert_pt2(model, example_inputs, quantizer)
8247

83-
# Quantize the model
84-
quantized_model = quantize_pt2(model, example_inputs)
48+
# Get reference outputs from converted model
49+
ref_outputs = converted_model(*example_inputs)
8550

86-
# Get edge program (also called in export_to_cadence)
87-
edge_prog_manager = export_to_edge(quantized_model, example_inputs)
51+
# Quantize the model (note: quantizer needs to be the same as
52+
# the one used in convert_pt2)
53+
quantized_model = fuse_pt2(converted_model, quantizer)
8854

8955
# Get edge program after Cadence specific passes
9056
cadence_prog_manager = export_to_cadence(quantized_model, example_inputs)
9157

58+
# Get executorch program after Cadence specific passes
9259
exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()
9360

9461
logging.info("Final exported graph:\n")
9562
exec_prog.exported_program().graph_module.graph.print_tabular()
9663

97-
# Print some information to terminal
98-
print_ops_info(
99-
edge_prog_manager.exported_program().graph_module,
100-
cadence_prog_manager.exported_program().graph_module,
101-
)
102-
10364
forward_test_data = BundledProgramManager.bundled_program_test_data_gen(
10465
method="forward", inputs=example_inputs, expected_outputs=ref_outputs
10566
)
@@ -110,9 +71,9 @@ def export_model(
11071
forward_test_data,
11172
)
11273
# Save the program as pte (default name is CadenceDemoModel.pte)
113-
_save_pte_program(exec_prog, file_name, working_dir)
74+
save_pte_program(exec_prog, file_name, working_dir)
11475
# Save the program as btpe (default name is CadenceDemoModel.bpte)
115-
_save_bpte_program(buffer, file_name, working_dir)
76+
save_bpte_program(buffer, file_name, working_dir)
11677

11778
logging.debug(
11879
f"Executorch bundled program buffer saved to {file_name} is {len(buffer)} total bytes"

backends/cadence/aot/utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
import logging
1010
import operator
11+
import os
1112
from typing import Dict, List, Tuple
1213

1314
import torch
14-
from executorch.exir import memory
15+
16+
from executorch.exir import ExecutorchProgramManager, memory
1517
from executorch.exir.dialects._ops import ops as exir_ops
1618
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
1719
from tabulate import tabulate
@@ -185,3 +187,36 @@ def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
185187
if node.target == torch.ops.aten.scaled_dot_product_attention.default:
186188
return True
187189
return False
190+
191+
192+
def save_pte_program(
193+
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
194+
) -> None:
195+
if model_name.endswith(".pte"):
196+
filename = model_name
197+
else:
198+
filename = os.path.join(output_dir, f"{model_name}.pte")
199+
200+
try:
201+
with open(filename, "wb") as file:
202+
prog.write_to_file(file)
203+
logging.info(f"Saved exported program to {filename}")
204+
except Exception as e:
205+
logging.error(f"Error while saving to {filename}: {e}")
206+
207+
208+
def save_bpte_program(
209+
buffer: bytes,
210+
model_name: str,
211+
output_dir: str = "",
212+
) -> None:
213+
if model_name.endswith(".bpte"):
214+
filename = model_name
215+
else:
216+
filename = os.path.join(output_dir, f"{model_name}.bpte")
217+
try:
218+
with open(filename, "wb") as f:
219+
f.write(buffer)
220+
logging.info(f"Saved exported program to {filename}")
221+
except Exception as e:
222+
logging.error(f"Error while saving to {output_dir}: {e}")

0 commit comments

Comments
 (0)