Skip to content

Commit c20df2a

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for EdgeProgramManager and ExecutorchProgramManager in etrecord (#788)
Summary: Pull Request resolved: #788 Reviewed By: Jack-Khuu Differential Revision: D50129349 Pulled By: tarun292 fbshipit-source-id: e3987dba9f464f4c068508641e662def873d7fd4
1 parent 592ac12 commit c20df2a

File tree

7 files changed

+89
-64
lines changed

7 files changed

+89
-64
lines changed

docs/source/sdk-etrecord.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ will be expected to provide the Edge Dialect program (returned by the call to ``
3434
the ExecuTorch program (returned by the call to ``to_executorch()``), and optional models that
3535
they are interested in working with via our tooling.
3636

37+
.. warning::
38+
Users should do a deepcopy of the output of to_edge() and pass in the deepcopy to the generate_etrecord API. This is needed because the subsequent call, to_executorch(), does an in-place mutation and will lose debug data in the process.
39+
3740
.. currentmodule:: sdk.etrecord._etrecord
3841
.. autofunction:: generate_etrecord
3942

exir/program/_program.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,9 +878,14 @@ def dump_executorch_program(self, verbose: bool = False) -> None:
878878

879879
@property
880880
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
881-
# TODO ask Tarun what the docstring here should be.
882881
return self._emitter_output.debug_handle_map
883882

883+
@property
884+
def delegate_map(
885+
self,
886+
) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
887+
return self._emitter_output.method_to_delegate_debug_id_map
888+
884889
@property
885890
def executorch_program(self) -> Program:
886891
"""

sdk/etrecord/_etrecord.py

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
from typing import Dict, List, Optional, Union
1111
from zipfile import BadZipFile, ZipFile
1212

13+
from executorch import exir
1314
from executorch.exir import (
15+
EdgeProgramManager,
1416
ExecutorchProgram,
17+
ExecutorchProgramManager,
1518
ExirExportedProgram,
1619
ExportedProgram,
1720
MultiMethodExecutorchProgram,
@@ -65,7 +68,9 @@ def _handle_multi_method_exported_program(
6568

6669
def _handle_export_module(
6770
etrecord_zip: ZipFile,
68-
export_module: Union[MultiMethodExirExportedProgram, ExirExportedProgram],
71+
export_module: Union[
72+
MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
73+
],
6974
module_name: str,
7075
) -> None:
7176
if isinstance(export_module, MultiMethodExirExportedProgram):
@@ -74,45 +79,21 @@ def _handle_export_module(
7479
_handle_exported_program(
7580
etrecord_zip, module_name, "forward", export_module.exported_program
7681
)
82+
elif isinstance(
83+
export_module,
84+
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
85+
):
86+
for method in export_module.methods:
87+
_handle_exported_program(
88+
etrecord_zip,
89+
module_name,
90+
method,
91+
export_module.exported_program(method),
92+
)
7793
else:
7894
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")
7995

8096

81-
def _handle_executorch_program(
82-
etrecord_zip: ZipFile,
83-
program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
84-
) -> None:
85-
if isinstance(program, MultiMethodExecutorchProgram):
86-
# Do a dummy read of the program here to make sure that the emitter runs
87-
# under the hood which will result in the debug handle map being generated.
88-
program.program
89-
90-
_handle_multi_method_exported_program(
91-
etrecord_zip,
92-
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
93-
program._executorch_dialect_ir_program,
94-
)
95-
96-
elif isinstance(program, ExecutorchProgram):
97-
# Do a dummy read of the program here to make sure that the emitter runs
98-
# under the hood which will result in the debug handle map being generated.
99-
program.program
100-
101-
_handle_exported_program(
102-
etrecord_zip,
103-
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
104-
"forward",
105-
program.dump_exported_program(),
106-
)
107-
108-
etrecord_zip.writestr(ETRecordReservedFileNames.PROGRAM_BUFFER, program.buffer)
109-
110-
else:
111-
raise RuntimeError(
112-
f"program passed in should be either ExecutorchProgram or MultiMethodExecutorchProgram. {type(program)}"
113-
)
114-
115-
11697
def _handle_edge_dialect_exported_program(
11798
etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
11899
) -> None:
@@ -130,12 +111,16 @@ def _handle_edge_dialect_exported_program(
130111

131112
def generate_etrecord(
132113
etrecord_path: str,
133-
edge_dialect_program: ExirExportedProgram,
134-
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
114+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
115+
executorch_program: Union[
116+
ExecutorchProgram, MultiMethodExecutorchProgram, ExecutorchProgramManager
117+
],
135118
export_modules: Optional[
136119
Dict[
137120
str,
138-
Union[MultiMethodExirExportedProgram, ExirExportedProgram],
121+
Union[
122+
MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
123+
],
139124
]
140125
] = None,
141126
) -> None:
@@ -151,10 +136,9 @@ def generate_etrecord(
151136
152137
Args:
153138
etrecord_path: Path to where the `ETRecord` file will be saved to.
154-
edge_dialect_program: `ExirExportedProgram` for this model returned by the call to to_edge()
155-
executorch_program: `ExecutorchProgram` or `MultiMethodExecutorchProgram` for this model returned by the
156-
call to `to_executorch()`
157-
export_modules: A dictionary of graph modules with the key being the user provided name and the
139+
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
140+
executorch_program: `ExecutorchProgramManager` for this model returned by the call to `to_executorch()`
141+
export_modules[Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
158142
value being the corresponding exported module. The exported graph modules can be either the
159143
output of `capture()` or `to_edge()`.
160144
@@ -179,12 +163,23 @@ def generate_etrecord(
179163
)
180164
_handle_export_module(etrecord_zip, export_module, module_name)
181165

182-
_handle_executorch_program(etrecord_zip, executorch_program)
183-
184-
_handle_edge_dialect_exported_program(
185-
etrecord_zip,
186-
edge_dialect_program.exported_program,
187-
)
166+
if isinstance(
167+
edge_dialect_program,
168+
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
169+
):
170+
_handle_edge_dialect_exported_program(
171+
etrecord_zip,
172+
edge_dialect_program.exported_program(),
173+
)
174+
elif isinstance(edge_dialect_program, ExirExportedProgram):
175+
_handle_edge_dialect_exported_program(
176+
etrecord_zip,
177+
edge_dialect_program.exported_program,
178+
)
179+
else:
180+
raise RuntimeError(
181+
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
182+
)
188183

189184
etrecord_zip.writestr(
190185
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,

sdk/etrecord/tests/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ python_unittest(
77
name = "etrecord_test",
88
srcs = ["etrecord_test.py"],
99
deps = [
10+
"//caffe2:torch",
1011
"//executorch/exir:lib",
1112
"//executorch/exir/tests:models",
1213
"//executorch/sdk/etrecord:etrecord",
@@ -17,6 +18,7 @@ python_library(
1718
name = "etrecord_test_library",
1819
srcs = ["etrecord_test.py"],
1920
deps = [
21+
"//caffe2:torch",
2022
"//executorch/exir:lib",
2123
"//executorch/exir/tests:models",
2224
"//executorch/sdk/etrecord:etrecord",

sdk/etrecord/tests/etrecord_test.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
import executorch.exir.tests.models as models
1313
from executorch import exir
14+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
1415
from executorch.sdk.etrecord import generate_etrecord, parse_etrecord
1516
from executorch.sdk.etrecord._etrecord import ETRecordReservedFileNames
17+
from torch.export import export
1618

1719

1820
# TODO : T154728484 Add test cases to cover multiple entry points
@@ -27,8 +29,16 @@ def get_test_model(self):
2729
)
2830
edge_output_copy = copy.deepcopy(edge_output)
2931
et_output = edge_output.to_executorch()
30-
buffer = et_output.buffer
31-
return (captured_output_copy, edge_output_copy, et_output, buffer)
32+
return (captured_output_copy, edge_output_copy, et_output)
33+
34+
def get_test_model_with_manager(self):
35+
f = models.BasicSinMax()
36+
aten_dialect = export(f, f.get_random_inputs())
37+
edge_program: EdgeProgramManager = to_edge(
38+
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
39+
)
40+
edge_program_copy = copy.deepcopy(edge_program)
41+
return (aten_dialect, edge_program_copy, edge_program.to_executorch())
3242

3343
# Serialized and deserialized graph modules are not completely the same, so we check
3444
# that they are close enough and match especially on the parameters we care about in the SDK.
@@ -47,7 +57,7 @@ def check_graph_closeness(self, graph_a, graph_b):
4757
)
4858

4959
def test_etrecord_generation(self):
50-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
60+
captured_output, edge_output, et_output = self.get_test_model()
5161
with tempfile.TemporaryDirectory() as tmpdirname:
5262
generate_etrecord(
5363
tmpdirname + "/etrecord.bin",
@@ -67,18 +77,32 @@ def test_etrecord_generation(self):
6777
etrecord.edge_dialect_program,
6878
edge_output.exported_program.graph_module,
6979
)
80+
self.assertEqual(
81+
etrecord._debug_handle_map,
82+
json.loads(json.dumps(et_output.debug_handle_map)),
83+
)
84+
85+
def test_etrecord_generation_with_manager(self):
86+
captured_output, edge_output, et_output = self.get_test_model_with_manager()
87+
with tempfile.TemporaryDirectory() as tmpdirname:
88+
generate_etrecord(
89+
tmpdirname + "/etrecord.bin",
90+
edge_output,
91+
et_output,
92+
)
93+
94+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
7095
self.check_graph_closeness(
71-
etrecord.graph_map["et_dialect_graph_module/forward"],
72-
et_output.dump_exported_program(),
96+
etrecord.edge_dialect_program,
97+
edge_output.exported_program().graph_module,
7398
)
7499
self.assertEqual(
75100
etrecord._debug_handle_map,
76101
json.loads(json.dumps(et_output.debug_handle_map)),
77102
)
78-
self.assertEqual(etrecord.program_buffer, program_buffer)
79103

80104
def test_etrecord_invalid_input(self):
81-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
105+
captured_output, edge_output, et_output = self.get_test_model()
82106
with tempfile.TemporaryDirectory() as tmpdirname:
83107
with self.assertRaises(RuntimeError):
84108
generate_etrecord(
@@ -89,7 +113,7 @@ def test_etrecord_invalid_input(self):
89113
)
90114

91115
def test_etrecord_reserved_name(self):
92-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
116+
captured_output, edge_output, et_output = self.get_test_model()
93117
with tempfile.TemporaryDirectory() as tmpdirname:
94118
for reserved_name in ETRecordReservedFileNames:
95119
with self.assertRaises(RuntimeError):

sdk/inspector/tests/inspector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_inspector_get_exported_program(self):
197197
)
198198

199199
# Gen a mock etrecord
200-
captured_output, edge_output, et_output, _ = TestETRecord().get_test_model()
200+
captured_output, edge_output, et_output = TestETRecord().get_test_model()
201201
with tempfile.TemporaryDirectory() as tmpdirname:
202202
generate_etrecord(
203203
tmpdirname + "/etrecord.bin",

sdk/inspector/tests/inspector_utils_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
class TestInspectorUtils(unittest.TestCase):
2929
def test_gen_graphs_from_etrecord(self):
30-
captured_output, edge_output, et_output, _ = TestETRecord().get_test_model()
30+
captured_output, edge_output, et_output = TestETRecord().get_test_model()
3131
with tempfile.TemporaryDirectory() as tmpdirname:
3232
generate_etrecord(
3333
tmpdirname + "/etrecord.bin",
@@ -43,15 +43,11 @@ def test_gen_graphs_from_etrecord(self):
4343
graphs = gen_graphs_from_etrecord(etrecord)
4444

4545
self.assertTrue("aten_dialect_output/forward" in graphs)
46-
self.assertTrue("et_dialect_graph_module/forward" in graphs)
4746
self.assertTrue(EDGE_DIALECT_GRAPH_KEY in graphs)
4847

4948
self.assertTrue(
5049
isinstance(graphs["aten_dialect_output/forward"], FXOperatorGraph)
5150
)
52-
self.assertTrue(
53-
isinstance(graphs["et_dialect_graph_module/forward"], FXOperatorGraph)
54-
)
5551
self.assertTrue(isinstance(graphs[EDGE_DIALECT_GRAPH_KEY], FXOperatorGraph))
5652

5753
def test_create_debug_handle_to_op_node_mapping(self):

0 commit comments

Comments
 (0)