Skip to content

Commit 2856654

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for EdgeProgramManager and ExecutorchProgramManager in etrecord (#788)
Summary: Pull Request resolved: #788 Reviewed By: Jack-Khuu Differential Revision: D50129349 Pulled By: tarun292 fbshipit-source-id: e0db415423063887f29ce1ac6980c094e914470d
1 parent 9ce7fb0 commit 2856654

File tree

6 files changed

+68
-63
lines changed

6 files changed

+68
-63
lines changed

docs/source/sdk-etrecord.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ will be expected to provide the Edge Dialect program (returned by the call to ``
3434
the ExecuTorch program (returned by the call to ``to_executorch()``), and optional models that
3535
they are interested in working with via our tooling.
3636

37+
.. warning::
38+
Users should do a deepcopy of the output of to_edge() and pass in the deepcopy to the generate_etrecord API. This is needed because the subsequent call, to_executorch(), does an in-place mutation and will lose debug data in the process.
39+
3740
.. currentmodule:: sdk.etrecord._etrecord
3841
.. autofunction:: generate_etrecord
3942

exir/program/_program.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,9 +878,14 @@ def dump_executorch_program(self, verbose: bool = False) -> None:
878878

879879
@property
880880
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
881-
# TODO ask Tarun what the docstring here should be.
882881
return self._emitter_output.debug_handle_map
883882

883+
@property
884+
def delegate_map(self) -> Dict[
885+
str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
886+
]:
887+
return self._emitter_output.method_to_delegate_debug_id_map
888+
884889
@property
885890
def executorch_program(self) -> Program:
886891
"""

sdk/etrecord/_etrecord.py

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
from typing import Dict, List, Optional, Union
1111
from zipfile import BadZipFile, ZipFile
1212

13+
import executorch
1314
from executorch.exir import (
15+
EdgeProgramManager,
1416
ExecutorchProgram,
17+
ExecutorchProgramManager,
1518
ExirExportedProgram,
1619
ExportedProgram,
1720
MultiMethodExecutorchProgram,
@@ -65,7 +68,7 @@ def _handle_multi_method_exported_program(
6568

6669
def _handle_export_module(
6770
etrecord_zip: ZipFile,
68-
export_module: Union[MultiMethodExirExportedProgram, ExirExportedProgram],
71+
export_module: Union[MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager],
6972
module_name: str,
7073
) -> None:
7174
if isinstance(export_module, MultiMethodExirExportedProgram):
@@ -74,45 +77,14 @@ def _handle_export_module(
7477
_handle_exported_program(
7578
etrecord_zip, module_name, "forward", export_module.exported_program
7679
)
80+
elif isinstance(export_module, (EdgeProgramManager, executorch.exir.program._program.EdgeProgramManager)):
81+
for method in export_module.methods:
82+
_handle_exported_program(
83+
etrecord_zip, module_name, method, export_module.exported_program(method)
84+
)
7785
else:
7886
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")
7987

80-
81-
def _handle_executorch_program(
82-
etrecord_zip: ZipFile,
83-
program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
84-
) -> None:
85-
if isinstance(program, MultiMethodExecutorchProgram):
86-
# Do a dummy read of the program here to make sure that the emitter runs
87-
# under the hood which will result in the debug handle map being generated.
88-
program.program
89-
90-
_handle_multi_method_exported_program(
91-
etrecord_zip,
92-
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
93-
program._executorch_dialect_ir_program,
94-
)
95-
96-
elif isinstance(program, ExecutorchProgram):
97-
# Do a dummy read of the program here to make sure that the emitter runs
98-
# under the hood which will result in the debug handle map being generated.
99-
program.program
100-
101-
_handle_exported_program(
102-
etrecord_zip,
103-
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
104-
"forward",
105-
program.dump_exported_program(),
106-
)
107-
108-
etrecord_zip.writestr(ETRecordReservedFileNames.PROGRAM_BUFFER, program.buffer)
109-
110-
else:
111-
raise RuntimeError(
112-
f"program passed in should be either ExecutorchProgram or MultiMethodExecutorchProgram. {type(program)}"
113-
)
114-
115-
11688
def _handle_edge_dialect_exported_program(
11789
etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
11890
) -> None:
@@ -130,12 +102,12 @@ def _handle_edge_dialect_exported_program(
130102

131103
def generate_etrecord(
132104
etrecord_path: str,
133-
edge_dialect_program: ExirExportedProgram,
134-
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
105+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
106+
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram, ExecutorchProgramManager],
135107
export_modules: Optional[
136108
Dict[
137109
str,
138-
Union[MultiMethodExirExportedProgram, ExirExportedProgram],
110+
Union[MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager],
139111
]
140112
] = None,
141113
) -> None:
@@ -151,10 +123,9 @@ def generate_etrecord(
151123
152124
Args:
153125
etrecord_path: Path to where the `ETRecord` file will be saved to.
154-
edge_dialect_program: `ExirExportedProgram` for this model returned by the call to to_edge()
155-
executorch_program: `ExecutorchProgram` or `MultiMethodExecutorchProgram` for this model returned by the
156-
call to `to_executorch()`
157-
export_modules: A dictionary of graph modules with the key being the user provided name and the
126+
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
127+
executorch_program: `ExecutorchProgramManager` for this model returned by the call to `to_executorch()`
128+
export_modules[Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
158129
value being the corresponding exported module. The exported graph modules can be either the
159130
output of `capture()` or `to_edge()`.
160131
@@ -179,12 +150,19 @@ def generate_etrecord(
179150
)
180151
_handle_export_module(etrecord_zip, export_module, module_name)
181152

182-
_handle_executorch_program(etrecord_zip, executorch_program)
153+
if isinstance(edge_dialect_program, (EdgeProgramManager, executorch.exir.program._program.EdgeProgramManager)):
154+
_handle_edge_dialect_exported_program(
155+
etrecord_zip,
156+
edge_dialect_program.exported_program(),
157+
)
158+
elif isinstance(edge_dialect_program, ExirExportedProgram):
159+
_handle_edge_dialect_exported_program(
160+
etrecord_zip,
161+
edge_dialect_program.exported_program,
162+
)
163+
else:
164+
raise RuntimeError(f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}.")
183165

184-
_handle_edge_dialect_exported_program(
185-
etrecord_zip,
186-
edge_dialect_program.exported_program,
187-
)
188166

189167
etrecord_zip.writestr(
190168
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,

sdk/etrecord/tests/etrecord_test.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
import executorch.exir.tests.models as models
1313
from executorch import exir
14+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
1415
from executorch.sdk.etrecord import generate_etrecord, parse_etrecord
1516
from executorch.sdk.etrecord._etrecord import ETRecordReservedFileNames
17+
from torch.export import export
1618

1719

1820
# TODO : T154728484 Add test cases to cover multiple entry points
@@ -28,7 +30,14 @@ def get_test_model(self):
2830
edge_output_copy = copy.deepcopy(edge_output)
2931
et_output = edge_output.to_executorch()
3032
buffer = et_output.buffer
31-
return (captured_output_copy, edge_output_copy, et_output, buffer)
33+
return (captured_output_copy, edge_output_copy, et_output)
34+
35+
def get_test_model_with_manager(self):
36+
f = models.BasicSinMax()
37+
aten_dialect = export(f, f.get_random_inputs())
38+
edge_program: EdgeProgramManager = to_edge(aten_dialect, compile_config = EdgeCompileConfig(_check_ir_validity=False))
39+
edge_program_copy = copy.deepcopy(edge_program)
40+
return (aten_dialect, edge_program_copy, edge_program.to_executorch())
3241

3342
# Serialized and deserialized graph modules are not completely the same, so we check
3443
# that they are close enough and match especially on the parameters we care about in the SDK.
@@ -47,7 +56,7 @@ def check_graph_closeness(self, graph_a, graph_b):
4756
)
4857

4958
def test_etrecord_generation(self):
50-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
59+
captured_output, edge_output, et_output = self.get_test_model()
5160
with tempfile.TemporaryDirectory() as tmpdirname:
5261
generate_etrecord(
5362
tmpdirname + "/etrecord.bin",
@@ -67,18 +76,32 @@ def test_etrecord_generation(self):
6776
etrecord.edge_dialect_program,
6877
edge_output.exported_program.graph_module,
6978
)
79+
self.assertEqual(
80+
etrecord._debug_handle_map,
81+
json.loads(json.dumps(et_output.debug_handle_map)),
82+
)
83+
84+
def test_etrecord_generation_with_manager(self):
85+
captured_output, edge_output, et_output = self.get_test_model_with_manager()
86+
with tempfile.TemporaryDirectory() as tmpdirname:
87+
generate_etrecord(
88+
tmpdirname + "/etrecord.bin",
89+
edge_output,
90+
et_output,
91+
)
92+
93+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
7094
self.check_graph_closeness(
71-
etrecord.graph_map["et_dialect_graph_module/forward"],
72-
et_output.dump_exported_program(),
95+
etrecord.edge_dialect_program,
96+
edge_output.exported_program().graph_module,
7397
)
7498
self.assertEqual(
7599
etrecord._debug_handle_map,
76100
json.loads(json.dumps(et_output.debug_handle_map)),
77101
)
78-
self.assertEqual(etrecord.program_buffer, program_buffer)
79102

80103
def test_etrecord_invalid_input(self):
81-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
104+
captured_output, edge_output, et_output = self.get_test_model()
82105
with tempfile.TemporaryDirectory() as tmpdirname:
83106
with self.assertRaises(RuntimeError):
84107
generate_etrecord(
@@ -89,7 +112,7 @@ def test_etrecord_invalid_input(self):
89112
)
90113

91114
def test_etrecord_reserved_name(self):
92-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
115+
captured_output, edge_output, et_output = self.get_test_model()
93116
with tempfile.TemporaryDirectory() as tmpdirname:
94117
for reserved_name in ETRecordReservedFileNames:
95118
with self.assertRaises(RuntimeError):

sdk/inspector/tests/inspector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_inspector_get_exported_program(self):
197197
)
198198

199199
# Gen a mock etrecord
200-
captured_output, edge_output, et_output, _ = TestETRecord().get_test_model()
200+
captured_output, edge_output, et_output = TestETRecord().get_test_model()
201201
with tempfile.TemporaryDirectory() as tmpdirname:
202202
generate_etrecord(
203203
tmpdirname + "/etrecord.bin",

sdk/inspector/tests/inspector_utils_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
class TestInspectorUtils(unittest.TestCase):
2929
def test_gen_graphs_from_etrecord(self):
30-
captured_output, edge_output, et_output, _ = TestETRecord().get_test_model()
30+
captured_output, edge_output, et_output = TestETRecord().get_test_model()
3131
with tempfile.TemporaryDirectory() as tmpdirname:
3232
generate_etrecord(
3333
tmpdirname + "/etrecord.bin",
@@ -43,15 +43,11 @@ def test_gen_graphs_from_etrecord(self):
4343
graphs = gen_graphs_from_etrecord(etrecord)
4444

4545
self.assertTrue("aten_dialect_output/forward" in graphs)
46-
self.assertTrue("et_dialect_graph_module/forward" in graphs)
4746
self.assertTrue(EDGE_DIALECT_GRAPH_KEY in graphs)
4847

4948
self.assertTrue(
5049
isinstance(graphs["aten_dialect_output/forward"], FXOperatorGraph)
5150
)
52-
self.assertTrue(
53-
isinstance(graphs["et_dialect_graph_module/forward"], FXOperatorGraph)
54-
)
5551
self.assertTrue(isinstance(graphs[EDGE_DIALECT_GRAPH_KEY], FXOperatorGraph))
5652

5753
def test_create_debug_handle_to_op_node_mapping(self):

0 commit comments

Comments
 (0)