Commit 71470a7

kimishpatel authored and facebook-github-bot committed

Move examples export to two stage export (#238)

Summary:
Pull Request resolved: #238

Following the alignment, this diff moves the examples' export flow to the two-stage export API. Changes to the quantization examples will follow. However, before landing, the torch nightly update must first land in executorch.

Reviewed By: guangy10

Differential Revision: D49025972

fbshipit-source-id: 0b2c71dd7d562a714ff723b87b660d564a96cc37

1 parent 27dabc5 commit 71470a7
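For orientation, here is a minimal sketch of the two-stage flow these examples move to. The model and inputs are placeholder stand-ins; export_to_edge is the helper from examples/export/utils.py refactored below.

import torch
import torch._export as export

from executorch.examples.export.utils import export_to_edge

model = torch.nn.Linear(4, 4).eval()  # placeholder eager model
example_inputs = (torch.randn(1, 4),)

# stage 1: pre-autograd export. eventually this will become torch.export
model = export.capture_pre_autograd_graph(model, example_inputs)

# stage 2: post-autograd capture and lowering to the edge dialect
edge = export_to_edge(model, example_inputs)
exec_prog = edge.to_executorch()  # serialize to an ExecuTorch program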

File tree

7 files changed: +80 -27 lines changed


examples/backend/xnnpack_examples.py

Lines changed: 6 additions & 2 deletions
@@ -9,11 +9,12 @@
 import argparse
 import logging
 
+import torch._export as export
+
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackFloatingPointPartitioner,
     XnnpackQuantizedPartitioner,
 )
-
 from executorch.exir import CaptureConfig, EdgeCompileConfig
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
@@ -77,6 +78,8 @@
     )
 
     model = model.eval()
+    # pre-autograd export. eventually this will become torch.export
+    model = export.capture_pre_autograd_graph(model, example_inputs)
 
     partitioner = XnnpackFloatingPointPartitioner
     if args.quantize:
@@ -85,10 +88,11 @@
         # TODO(T161849167): Partitioner will eventually be a single partitioner for both fp32 and quantized models
         partitioner = XnnpackQuantizedPartitioner
 
+    capture_config = CaptureConfig(enable_aot=True)
+
     edge = export_to_edge(
         model,
         example_inputs,
-        capture_config=CaptureConfig(enable_aot=True),
         edge_compile_config=EdgeCompileConfig(
            # TODO(T162080278): Duplicated Dequant nodes will be in quantizer spec
            _check_ir_validity=False
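The hunks above end just before the delegation step. As a hedged sketch (not part of this diff), the file then hands the edge program to the chosen partitioner, assuming to_backend accepts an exported program plus a partitioner, as the imports above suggest:

# hedged sketch of the delegation step that follows these hunks; assumes
# to_backend(exported_program, partitioner), per the imports in this file
edge.exported_program = to_backend(edge.exported_program, partitioner)
exec_prog = edge.to_executorch()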

examples/export/export_and_delegate.py

Lines changed: 15 additions & 3 deletions
@@ -10,6 +10,7 @@
 import logging
 
 import torch
+import torch._export as export
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
@@ -18,7 +19,7 @@
 
 from ..models import MODEL_NAME_TO_MODEL
 
-from .utils import export_to_edge
+from ..utils import export_to_edge
 
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -59,6 +60,8 @@ def export_compsite_module_with_lower_graph():
     m, m_inputs = MODEL_NAME_TO_MODEL.get("add_mul")()
     m = m.eval()
     m_inputs = m.get_example_inputs()
+    # pre-autograd export. eventually this will become torch.export
+    m = export.capture_pre_autograd_graph(m, m_inputs)
     edge = export_to_edge(m, m_inputs)
     logging.info(f"Exported graph:\n{edge.exported_program.graph}")
 
@@ -78,7 +81,11 @@ def forward(self, *args):
             return torch.sub(self.lowered_graph(*args), torch.ones(1))
 
     # Get the graph for the composite module, which includes lowered graph
-    composited_edge = export_to_edge(CompositeModule(), m_inputs)
+    m = CompositeModule()
+    m = m.eval()
+    # pre-autograd export. eventually this will become torch.export
+    m = export.capture_pre_autograd_graph(m, m_inputs)
+    composited_edge = export_to_edge(m, m_inputs)
 
     # The graph module is still runnerable
     composited_edge.exported_program.graph_module(*m_inputs)
@@ -125,7 +132,10 @@ def get_example_inputs(self):
             return (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
 
     m = Model()
-    edge = export_to_edge(m, m.get_example_inputs())
+    m_inputs = m.get_example_inputs()
+    # pre-autograd export. eventually this will become torch.export
+    m = export.capture_pre_autograd_graph(m, m_inputs)
+    edge = export_to_edge(m, m_inputs)
     logging.info(f"Exported graph:\n{edge.exported_program.graph}")
 
     # Lower to backend_with_compiler_demo
@@ -159,6 +169,8 @@ def export_and_lower_the_whole_graph():
     m, m_inputs = MODEL_NAME_TO_MODEL.get("add_mul")()
     m = m.eval()
     m_inputs = m.get_example_inputs()
+    # pre-autograd export. eventually this will become torch.export
+    m = export.capture_pre_autograd_graph(m, m_inputs)
     edge = export_to_edge(m, m_inputs)
     logging.info(f"Exported graph:\n{edge.exported_program.graph}")

examples/export/test/TARGETS

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/examples/export:utils",
         "//executorch/examples/models:models",
-        "//executorch/exir:lib",
        "//executorch/extension/pybindings:portable_lib",  # @manual
    ],
)

examples/export/test/test_export.py

Lines changed: 4 additions & 2 deletions
@@ -9,6 +9,7 @@
 from typing import Any, Callable
 
 import torch
+import torch._export as export
 
 from executorch.examples.export.utils import export_to_edge
 from executorch.examples.models import MODEL_NAME_TO_MODEL
@@ -32,9 +33,10 @@ def _assert_eager_lowered_same_result(
     takes the eager mode output and ET output, and returns True if they
     match.
     """
-    import executorch.exir as exir
 
-    edge_model = export_to_edge(eager_model, example_inputs)
+    eager_model = eager_model.eval()
+    model = export.capture_pre_autograd_graph(eager_model, example_inputs)
+    edge_model = export_to_edge(model, example_inputs)
 
     executorch_prog = edge_model.to_executorch()

examples/export/utils.py

Lines changed: 43 additions & 7 deletions
@@ -6,8 +6,15 @@
 
 import logging
 
+from typing import Tuple
+
 import executorch.exir as exir
 
+import torch
+import torch._export as export
+from executorch.exir.program import ExirExportedProgram
+from executorch.exir.tracer import Value
+
 
 _CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
 
@@ -17,26 +24,55 @@
 )
 
 
-def export_to_edge(
-    model,
-    example_inputs,
+def _to_core_aten(
+    model: torch.fx.GraphModule,
+    example_inputs: Tuple[Value, ...],
     capture_config=_CAPTURE_CONFIG,
+) -> ExirExportedProgram:
+    # post autograd export. eventually this will become .to_core_aten
+    if not isinstance(model, torch.fx.GraphModule):
+        raise ValueError(
+            f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
+        )
+    core_aten_exir_ep = exir.capture(model, example_inputs, capture_config)
+    logging.info(f"Core ATen graph:\n{core_aten_exir_ep.exported_program.graph}")
+    return core_aten_exir_ep
+
+
+def _core_aten_to_edge(
+    core_aten_exir_ep: ExirExportedProgram,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
-):
-    m = model.eval()
-    edge = exir.capture(m, example_inputs, capture_config).to_edge(edge_compile_config)
+) -> ExirExportedProgram:
+    edge = core_aten_exir_ep.to_edge(edge_compile_config)
     logging.info(f"Exported graph:\n{edge.exported_program.graph}")
     return edge
 
 
+def export_to_edge(
+    model: torch.fx.GraphModule,
+    example_inputs: Tuple[Value, ...],
+    capture_config=_CAPTURE_CONFIG,
+    edge_compile_config=_EDGE_COMPILE_CONFIG,
+) -> ExirExportedProgram:
+    core_aten_exir_ep = _to_core_aten(model, example_inputs, capture_config)
+    return _core_aten_to_edge(core_aten_exir_ep, edge_compile_config)
+
+
 def export_to_exec_prog(
     model,
     example_inputs,
     capture_config=_CAPTURE_CONFIG,
     edge_compile_config=_EDGE_COMPILE_CONFIG,
     backend_config=None,
 ):
-    edge_m = export_to_edge(model, example_inputs, capture_config, edge_compile_config)
+    m = model.eval()
+    # pre-autograd export. eventually this will become torch.export
+    m = export.capture_pre_autograd_graph(m, example_inputs)
+
+    core_aten_exir_ep = _to_core_aten(m, example_inputs)
+
+    edge_m = _core_aten_to_edge(core_aten_exir_ep, edge_compile_config)
+
     exec_prog = edge_m.to_executorch(backend_config)
     return exec_prog
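A short sketch of the caller contract this refactor creates (the model below is a placeholder; everything else comes from the diff above): _to_core_aten now rejects anything that is not a torch.fx.GraphModule, so the pre-autograd capture must happen before export_to_edge.

import torch
import torch._export as export

from executorch.examples.export.utils import export_to_edge

m = torch.nn.Linear(2, 2).eval()  # placeholder eager model
inputs = (torch.randn(1, 2),)

# calling export_to_edge(m, inputs) here would raise ValueError: m is a
# plain nn.Module, not the torch.fx.GraphModule that _to_core_aten checks for
m = export.capture_pre_autograd_graph(m, inputs)  # returns a GraphModule

edge = export_to_edge(m, inputs)  # _to_core_aten -> _core_aten_to_edge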

examples/quantization/example.py

Lines changed: 10 additions & 7 deletions
@@ -28,7 +28,7 @@
     XNNPACKQuantizer,
 )
 
-from ..export.export_example import export_to_exec_prog, save_pte_program
+from ..export.utils import export_to_edge, save_pte_program
 from ..models import MODEL_NAME_TO_MODEL
 from ..models.model_factory import EagerModelFactory
 from ..recipes.xnnpack_optimization import MODEL_NAME_TO_OPTIONS
@@ -154,23 +154,26 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
     end = time.perf_counter()
     # logging.info(f"Verify time: {end - start}s")
 
+    model = model.eval()
+    # pre-autograd export. eventually this will become torch.export
+    model = export.capture_pre_autograd_graph(model, example_inputs)
     start = time.perf_counter()
     quantized_model = quantize(model, example_inputs)
     end = time.perf_counter()
     # logging.info(f"Quantize time: {end - start}s")
 
     # TODO[T163161310]: takes a long time to export to exec prog and save inception_v4 quantized model
     if args.model_name != "ic4":
+
         start = time.perf_counter()
-        prog = export_to_exec_prog(
-            quantized_model,
-            copy.deepcopy(example_inputs),
-            edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
+        edge_m = export_to_edge(
+            quantized_model, example_inputs, edge_compile_config=edge_compile_config
         )
         end = time.perf_counter()
-        # logging.info(f"export_to_exec_prog time: {end - start}s")
+
         start = time.perf_counter()
+        prog = edge_m.to_executorch(None)
         save_pte_program(prog.buffer, f"{args.model_name}_quantized")
         end = time.perf_counter()
-        # logging.info(f"save_pte_program time: {end - start}s")
         logging.info("finished")

examples/quantization/utils.py

Lines changed: 2 additions & 5 deletions
@@ -7,7 +7,6 @@
 import copy
 import logging
 
-import torch._export as export
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
@@ -17,14 +16,12 @@
 
 def quantize(model, example_inputs):
     """This is the official recommended flow for quantization in pytorch 2.0 export"""
-    m = model.eval()
-    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
-    logging.info(f"Original model: {m}")
+    logging.info(f"Original model: {model}")
     quantizer = XNNPACKQuantizer()
     # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
     operator_config = get_symmetric_quantization_config(is_per_channel=False)
     quantizer.set_global(operator_config)
-    m = prepare_pt2e(m, quantizer)
+    m = prepare_pt2e(model, quantizer)
     # calibration
     m(*example_inputs)
     m = convert_pt2e(m)
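With the capture moved out of quantize(), callers own the pre-autograd export, as examples/quantization/example.py above now does. A minimal sketch with a placeholder model:

import torch
import torch._export as export

from executorch.examples.export.utils import export_to_edge
from executorch.examples.quantization.utils import quantize
from executorch.exir import EdgeCompileConfig

model = torch.nn.Linear(4, 4).eval()  # placeholder eager model
example_inputs = (torch.randn(1, 4),)

# the caller does the pre-autograd capture first...
model = export.capture_pre_autograd_graph(model, example_inputs)
# ...then quantize() runs prepare_pt2e / calibration / convert_pt2e
quantized = quantize(model, example_inputs)

edge = export_to_edge(
    quantized,
    example_inputs,
    edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
prog = edge.to_executorch()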
