Skip to content

Commit 7a46559

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for EdgeProgramManager and ExecutorchProgramManager in etrecord (#788)
Summary: ETRecord generation also needs to support `EdgeProgramManager` and `ExecutorchProgramManager`. This diff adds support for that. Reviewed By: Jack-Khuu Differential Revision: D50129349
1 parent b9d139c commit 7a46559

File tree

4 files changed

+65
-57
lines changed

4 files changed

+65
-57
lines changed

docs/source/sdk-etrecord.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ will be expected to provide the Edge Dialect program (returned by the call to ``
3434
the ExecuTorch program (returned by the call to ``to_executorch()``), and optional models that
3535
they are interested in working with via our tooling.
3636

37+
# NOTE : Users should do a deepcopy of the output of `to_edge()` and pass in that deepcopy into the generate_etercord API. This is needed because
38+
the subsequent call to_executorch() does an in place mutation and we will lose debug data in the process.
39+
3740
.. currentmodule:: sdk.etrecord._etrecord
3841
.. autofunction:: generate_etrecord
3942

exir/program/_program.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ def to_executorch(
195195
)
196196
return executorch_prog
197197

198+
@property
199+
def exported_program(self) -> ExportedProgram:
200+
return self.exported_program
201+
198202
def __deepcopy__(
199203
self, memo: Optional[Dict[int, Any]] = None
200204
) -> "ExirExportedProgram":
@@ -878,9 +882,14 @@ def dump_executorch_program(self, verbose: bool = False) -> None:
878882

879883
@property
880884
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
881-
# TODO ask Tarun what the docstring here should be.
882885
return self._emitter_output.debug_handle_map
883886

887+
@property
888+
def delegate_map(self) -> Dict[
889+
str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
890+
]:
891+
return self._emitter_output.method_to_delegate_debug_id_map
892+
884893
@property
885894
def executorch_program(self) -> Program:
886895
"""

sdk/etrecord/_etrecord.py

Lines changed: 22 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
from typing import Dict, List, Optional, Union
1111
from zipfile import BadZipFile, ZipFile
1212

13+
import executorch
1314
from executorch.exir import (
15+
EdgeProgramManager,
1416
ExecutorchProgram,
17+
ExecutorchProgramManager,
1518
ExirExportedProgram,
1619
ExportedProgram,
1720
MultiMethodExecutorchProgram,
@@ -65,7 +68,7 @@ def _handle_multi_method_exported_program(
6568

6669
def _handle_export_module(
6770
etrecord_zip: ZipFile,
68-
export_module: Union[MultiMethodExirExportedProgram, ExirExportedProgram],
71+
export_module: Union[MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager],
6972
module_name: str,
7073
) -> None:
7174
if isinstance(export_module, MultiMethodExirExportedProgram):
@@ -74,45 +77,14 @@ def _handle_export_module(
7477
_handle_exported_program(
7578
etrecord_zip, module_name, "forward", export_module.exported_program
7679
)
80+
elif isinstance(export_module, (EdgeProgramManager, executorch.exir.program._program.EdgeProgramManager)):
81+
for method in export_module.methods:
82+
_handle_exported_program(
83+
etrecord_zip, module_name, method, export_module.exported_program(method)
84+
)
7785
else:
7886
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")
7987

80-
81-
def _handle_executorch_program(
82-
etrecord_zip: ZipFile,
83-
program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
84-
) -> None:
85-
if isinstance(program, MultiMethodExecutorchProgram):
86-
# Do a dummy read of the program here to make sure that the emitter runs
87-
# under the hood which will result in the debug handle map being generated.
88-
program.program
89-
90-
_handle_multi_method_exported_program(
91-
etrecord_zip,
92-
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
93-
program._executorch_dialect_ir_program,
94-
)
95-
96-
elif isinstance(program, ExecutorchProgram):
97-
# Do a dummy read of the program here to make sure that the emitter runs
98-
# under the hood which will result in the debug handle map being generated.
99-
program.program
100-
101-
_handle_exported_program(
102-
etrecord_zip,
103-
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
104-
"forward",
105-
program.dump_exported_program(),
106-
)
107-
108-
etrecord_zip.writestr(ETRecordReservedFileNames.PROGRAM_BUFFER, program.buffer)
109-
110-
else:
111-
raise RuntimeError(
112-
f"program passed in should be either ExecutorchProgram or MultiMethodExecutorchProgram. {type(program)}"
113-
)
114-
115-
11688
def _handle_edge_dialect_exported_program(
11789
etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
11890
) -> None:
@@ -130,12 +102,12 @@ def _handle_edge_dialect_exported_program(
130102

131103
def generate_etrecord(
132104
etrecord_path: str,
133-
edge_dialect_program: ExirExportedProgram,
134-
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
105+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
106+
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram, ExecutorchProgramManager],
135107
export_modules: Optional[
136108
Dict[
137109
str,
138-
Union[MultiMethodExirExportedProgram, ExirExportedProgram],
110+
Union[MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager],
139111
]
140112
] = None,
141113
) -> None:
@@ -151,10 +123,9 @@ def generate_etrecord(
151123
152124
Args:
153125
etrecord_path: Path to where the `ETRecord` file will be saved to.
154-
edge_dialect_program: `ExirExportedProgram` for this model returned by the call to to_edge()
155-
executorch_program: `ExecutorchProgram` or `MultiMethodExecutorchProgram` for this model returned by the
156-
call to `to_executorch()`
157-
export_modules: A dictionary of graph modules with the key being the user provided name and the
126+
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
127+
executorch_program: `ExecutorchProgramManager` for this model returned by the call to `to_executorch()`
128+
export_modules[Optional]: ***Should be ignored by OSS users***. A dictionary of graph modules with the key being the user provided name and the
158129
value being the corresponding exported module. The exported graph modules can be either the
159130
output of `capture()` or `to_edge()`.
160131
@@ -179,12 +150,14 @@ def generate_etrecord(
179150
)
180151
_handle_export_module(etrecord_zip, export_module, module_name)
181152

182-
_handle_executorch_program(etrecord_zip, executorch_program)
153+
if isinstance(edge_dialect_program, (ExirExportedProgram, EdgeProgramManager, executorch.exir.program._program.EdgeProgramManager)):
154+
_handle_edge_dialect_exported_program(
155+
etrecord_zip,
156+
edge_dialect_program.exported_program(),
157+
)
158+
else:
159+
raise RuntimeError(f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}.")
183160

184-
_handle_edge_dialect_exported_program(
185-
etrecord_zip,
186-
edge_dialect_program.exported_program,
187-
)
188161

189162
etrecord_zip.writestr(
190163
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,

sdk/etrecord/tests/etrecord_test.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
import executorch.exir.tests.models as models
1313
from executorch import exir
14+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
1415
from executorch.sdk.etrecord import generate_etrecord, parse_etrecord
1516
from executorch.sdk.etrecord._etrecord import ETRecordReservedFileNames
17+
from torch.export import export
1618

1719

1820
# TODO : T154728484 Add test cases to cover multiple entry points
@@ -28,7 +30,14 @@ def get_test_model(self):
2830
edge_output_copy = copy.deepcopy(edge_output)
2931
et_output = edge_output.to_executorch()
3032
buffer = et_output.buffer
31-
return (captured_output_copy, edge_output_copy, et_output, buffer)
33+
return (captured_output_copy, edge_output_copy, et_output)
34+
35+
def get_test_model_with_manager(self):
36+
f = models.BasicSinMax()
37+
aten_dialect = export(f, f.get_random_inputs())
38+
edge_program: EdgeProgramManager = to_edge(aten_dialect, compile_config = EdgeCompileConfig(_check_ir_validity=False))
39+
edge_program_copy = copy.deepcopy(edge_program)
40+
return (aten_dialect, edge_program_copy, edge_program.to_executorch())
3241

3342
# Serialized and deserialized graph modules are not completely the same, so we check
3443
# that they are close enough and match especially on the parameters we care about in the SDK.
@@ -47,7 +56,7 @@ def check_graph_closeness(self, graph_a, graph_b):
4756
)
4857

4958
def test_etrecord_generation(self):
50-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
59+
captured_output, edge_output, et_output = self.get_test_model()
5160
with tempfile.TemporaryDirectory() as tmpdirname:
5261
generate_etrecord(
5362
tmpdirname + "/etrecord.bin",
@@ -67,18 +76,32 @@ def test_etrecord_generation(self):
6776
etrecord.edge_dialect_program,
6877
edge_output.exported_program.graph_module,
6978
)
79+
self.assertEqual(
80+
etrecord._debug_handle_map,
81+
json.loads(json.dumps(et_output.debug_handle_map)),
82+
)
83+
84+
def test_etrecord_generation_with_manager(self):
85+
captured_output, edge_output, et_output = self.get_test_model_with_manager()
86+
with tempfile.TemporaryDirectory() as tmpdirname:
87+
generate_etrecord(
88+
tmpdirname + "/etrecord.bin",
89+
edge_output,
90+
et_output,
91+
)
92+
93+
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
7094
self.check_graph_closeness(
71-
etrecord.graph_map["et_dialect_graph_module/forward"],
72-
et_output.dump_exported_program(),
95+
etrecord.edge_dialect_program,
96+
edge_output.exported_program().graph_module,
7397
)
7498
self.assertEqual(
7599
etrecord._debug_handle_map,
76100
json.loads(json.dumps(et_output.debug_handle_map)),
77101
)
78-
self.assertEqual(etrecord.program_buffer, program_buffer)
79102

80103
def test_etrecord_invalid_input(self):
81-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
104+
captured_output, edge_output, et_output = self.get_test_model()
82105
with tempfile.TemporaryDirectory() as tmpdirname:
83106
with self.assertRaises(RuntimeError):
84107
generate_etrecord(
@@ -89,7 +112,7 @@ def test_etrecord_invalid_input(self):
89112
)
90113

91114
def test_etrecord_reserved_name(self):
92-
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
115+
captured_output, edge_output, et_output = self.get_test_model()
93116
with tempfile.TemporaryDirectory() as tmpdirname:
94117
for reserved_name in ETRecordReservedFileNames:
95118
with self.assertRaises(RuntimeError):

0 commit comments

Comments
 (0)