Skip to content

Commit e6deec6

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Add export_to_edge util
Summary: Adding export_to_edge util to avoid possiblity of using different configs for exir.capture and to_edge. Reviewed By: kirklandsign Differential Revision: D48435836 fbshipit-source-id: f7029d04bfb845cf1ec4ad919b44b7defdfc6c83
1 parent d0cb851 commit e6deec6

File tree

4 files changed

+18
-22
lines changed

4 files changed

+18
-22
lines changed

examples/export/export_and_delegate.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from ..models import MODEL_NAME_TO_MODEL
2020

21-
from .utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
21+
from .utils import export_to_edge
2222

2323
"""
2424
BackendWithCompilerDemo is a test demo backend, only supports torch.mm and torch.add, here are some examples
@@ -52,7 +52,7 @@ def export_compsite_module_with_lower_graph():
5252
m, m_inputs = MODEL_NAME_TO_MODEL.get("add_mul")()
5353
m = m.eval()
5454
m_inputs = m.get_example_inputs()
55-
edge = exir.capture(m, m_inputs, _CAPTURE_CONFIG).to_edge(_EDGE_COMPILE_CONFIG)
55+
edge = export_to_edge(m, m_inputs)
5656
print("Exported graph:\n", edge.exported_program.graph)
5757

5858
# Lower AddMulModule to the demo backend
@@ -71,11 +71,7 @@ def forward(self, *args):
7171
return torch.sub(self.lowered_graph(*args), torch.ones(1))
7272

7373
# Get the graph for the composite module, which includes lowered graph
74-
composited_edge = exir.capture(
75-
CompositeModule(),
76-
m_inputs,
77-
_CAPTURE_CONFIG,
78-
).to_edge(_EDGE_COMPILE_CONFIG)
74+
composited_edge = export_to_edge(CompositeModule(), m_inputs)
7975

8076
# The graph module is still runnerable
8177
composited_edge.exported_program.graph_module(*m_inputs)
@@ -122,9 +118,7 @@ def get_example_inputs(self):
122118
return (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
123119

124120
m = Model()
125-
edge = exir.capture(m, m.get_example_inputs(), _CAPTURE_CONFIG).to_edge(
126-
_EDGE_COMPILE_CONFIG
127-
)
121+
edge = export_to_edge(m, m.get_example_inputs())
128122
print("Exported graph:\n", edge.exported_program.graph)
129123

130124
# Lower to backend_with_compiler_demo
@@ -158,7 +152,7 @@ def export_and_lower_the_whole_graph():
158152
m, m_inputs = MODEL_NAME_TO_MODEL.get("add_mul")()
159153
m = m.eval()
160154
m_inputs = m.get_example_inputs()
161-
edge = exir.capture(m, m_inputs, _CAPTURE_CONFIG).to_edge(_EDGE_COMPILE_CONFIG)
155+
edge = export_to_edge(m, m_inputs)
162156
print("Exported graph:\n", edge.exported_program.graph)
163157

164158
# Lower AddMulModule to the demo backend

examples/export/export_example.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,11 @@
1212

1313
from ..models import MODEL_NAME_TO_MODEL
1414

15-
from .utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
15+
from .utils import export_to_edge
1616

1717

1818
def export_to_pte(model_name, model, example_inputs):
19-
m = model.eval()
20-
edge = exir.capture(m, example_inputs, _CAPTURE_CONFIG).to_edge(
21-
_EDGE_COMPILE_CONFIG
22-
)
23-
print("Exported graph:\n", edge.exported_program.graph)
24-
19+
edge = export_to_edge(model, example_inputs)
2520
exec_prog = edge.to_executorch()
2621

2722
buffer = exec_prog.buffer

examples/export/test/test_export.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212

13-
from executorch.examples.export.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
13+
from executorch.examples.export.utils import export_to_edge
1414
from executorch.examples.models import MODEL_NAME_TO_MODEL
1515

1616
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
@@ -34,9 +34,7 @@ def _assert_eager_lowered_same_result(
3434
"""
3535
import executorch.exir as exir
3636

37-
edge_model = exir.capture(eager_model, example_inputs, _CAPTURE_CONFIG).to_edge(
38-
_EDGE_COMPILE_CONFIG
39-
)
37+
edge_model = export_to_edge(eager_model, example_inputs)
4038

4139
executorch_prog = edge_model.to_executorch()
4240
# pyre-ignore

examples/export/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,12 @@
1515
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
1616
_check_ir_validity=False,
1717
)
18+
19+
20+
def export_to_edge(model, example_inputs):
21+
m = model.eval()
22+
edge = exir.capture(m, example_inputs, _CAPTURE_CONFIG).to_edge(
23+
_EDGE_COMPILE_CONFIG
24+
)
25+
print("Exported graph:\n", edge.exported_program.graph)
26+
return edge

0 commit comments

Comments
 (0)