Skip to content

Add support for EdgeProgramManager and ExecutorchProgramManager in etrecord #788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/sdk-etrecord.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ will be expected to provide the Edge Dialect program (returned by the call to ``
the ExecuTorch program (returned by the call to ``to_executorch()``), and optional models that
they are interested in working with via our tooling.

.. warning::
Users should do a deepcopy of the output of to_edge() and pass in the deepcopy to the generate_etrecord API. This is needed because the subsequent call, to_executorch(), does an in-place mutation and will lose debug data in the process.

.. currentmodule:: sdk.etrecord._etrecord
.. autofunction:: generate_etrecord

Expand Down
7 changes: 6 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,9 +878,14 @@ def dump_executorch_program(self, verbose: bool = False) -> None:

@property
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
# TODO ask Tarun what the docstring here should be.
return self._emitter_output.debug_handle_map

@property
def delegate_map(
self,
) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
return self._emitter_output.method_to_delegate_debug_id_map

@property
def executorch_program(self) -> Program:
"""
Expand Down
93 changes: 44 additions & 49 deletions sdk/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from typing import Dict, List, Optional, Union
from zipfile import BadZipFile, ZipFile

from executorch import exir
from executorch.exir import (
EdgeProgramManager,
ExecutorchProgram,
ExecutorchProgramManager,
ExirExportedProgram,
ExportedProgram,
MultiMethodExecutorchProgram,
Expand Down Expand Up @@ -65,7 +68,9 @@ def _handle_multi_method_exported_program(

def _handle_export_module(
etrecord_zip: ZipFile,
export_module: Union[MultiMethodExirExportedProgram, ExirExportedProgram],
export_module: Union[
MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
],
module_name: str,
) -> None:
if isinstance(export_module, MultiMethodExirExportedProgram):
Expand All @@ -74,45 +79,21 @@ def _handle_export_module(
_handle_exported_program(
etrecord_zip, module_name, "forward", export_module.exported_program
)
elif isinstance(
export_module,
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
):
for method in export_module.methods:
_handle_exported_program(
etrecord_zip,
module_name,
method,
export_module.exported_program(method),
)
else:
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")


def _handle_executorch_program(
etrecord_zip: ZipFile,
program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
) -> None:
if isinstance(program, MultiMethodExecutorchProgram):
# Do a dummy read of the program here to make sure that the emitter runs
# under the hood which will result in the debug handle map being generated.
program.program

_handle_multi_method_exported_program(
etrecord_zip,
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
program._executorch_dialect_ir_program,
)

elif isinstance(program, ExecutorchProgram):
# Do a dummy read of the program here to make sure that the emitter runs
# under the hood which will result in the debug handle map being generated.
program.program

_handle_exported_program(
etrecord_zip,
ETRecordReservedFileNames.ET_DIALECT_GRAPH_MODULE,
"forward",
program.dump_exported_program(),
)

etrecord_zip.writestr(ETRecordReservedFileNames.PROGRAM_BUFFER, program.buffer)

else:
raise RuntimeError(
f"program passed in should be either ExecutorchProgram or MultiMethodExecutorchProgram. {type(program)}"
)


def _handle_edge_dialect_exported_program(
etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
) -> None:
Expand All @@ -130,12 +111,16 @@ def _handle_edge_dialect_exported_program(

def generate_etrecord(
etrecord_path: str,
edge_dialect_program: ExirExportedProgram,
executorch_program: Union[ExecutorchProgram, MultiMethodExecutorchProgram],
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
executorch_program: Union[
ExecutorchProgram, MultiMethodExecutorchProgram, ExecutorchProgramManager
],
export_modules: Optional[
Dict[
str,
Union[MultiMethodExirExportedProgram, ExirExportedProgram],
Union[
MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
],
]
] = None,
) -> None:
Expand All @@ -151,10 +136,9 @@ def generate_etrecord(

Args:
etrecord_path: Path to where the `ETRecord` file will be saved to.
edge_dialect_program: `ExirExportedProgram` for this model returned by the call to to_edge()
executorch_program: `ExecutorchProgram` or `MultiMethodExecutorchProgram` for this model returned by the
call to `to_executorch()`
export_modules: A dictionary of graph modules with the key being the user provided name and the
edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
executorch_program: `ExecutorchProgramManager` for this model returned by the call to `to_executorch()`
export_modules[Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
value being the corresponding exported module. The exported graph modules can be either the
output of `capture()` or `to_edge()`.

Expand All @@ -179,12 +163,23 @@ def generate_etrecord(
)
_handle_export_module(etrecord_zip, export_module, module_name)

_handle_executorch_program(etrecord_zip, executorch_program)

_handle_edge_dialect_exported_program(
etrecord_zip,
edge_dialect_program.exported_program,
)
if isinstance(
edge_dialect_program,
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
):
_handle_edge_dialect_exported_program(
etrecord_zip,
edge_dialect_program.exported_program(),
)
elif isinstance(edge_dialect_program, ExirExportedProgram):
_handle_edge_dialect_exported_program(
etrecord_zip,
edge_dialect_program.exported_program,
)
else:
raise RuntimeError(
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
)

etrecord_zip.writestr(
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
Expand Down
2 changes: 2 additions & 0 deletions sdk/etrecord/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ python_unittest(
name = "etrecord_test",
srcs = ["etrecord_test.py"],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/tests:models",
"//executorch/sdk/etrecord:etrecord",
Expand All @@ -17,6 +18,7 @@ python_library(
name = "etrecord_test_library",
srcs = ["etrecord_test.py"],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/tests:models",
"//executorch/sdk/etrecord:etrecord",
Expand Down
40 changes: 32 additions & 8 deletions sdk/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

import executorch.exir.tests.models as models
from executorch import exir
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from executorch.sdk.etrecord import generate_etrecord, parse_etrecord
from executorch.sdk.etrecord._etrecord import ETRecordReservedFileNames
from torch.export import export


# TODO : T154728484 Add test cases to cover multiple entry points
Expand All @@ -27,8 +29,16 @@ def get_test_model(self):
)
edge_output_copy = copy.deepcopy(edge_output)
et_output = edge_output.to_executorch()
buffer = et_output.buffer
return (captured_output_copy, edge_output_copy, et_output, buffer)
return (captured_output_copy, edge_output_copy, et_output)

def get_test_model_with_manager(self):
f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs())
edge_program: EdgeProgramManager = to_edge(
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
edge_program_copy = copy.deepcopy(edge_program)
return (aten_dialect, edge_program_copy, edge_program.to_executorch())

# Serialized and deserialized graph modules are not completely the same, so we check
# that they are close enough and match especially on the parameters we care about in the SDK.
Expand All @@ -47,7 +57,7 @@ def check_graph_closeness(self, graph_a, graph_b):
)

def test_etrecord_generation(self):
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
captured_output, edge_output, et_output = self.get_test_model()
with tempfile.TemporaryDirectory() as tmpdirname:
generate_etrecord(
tmpdirname + "/etrecord.bin",
Expand All @@ -67,18 +77,32 @@ def test_etrecord_generation(self):
etrecord.edge_dialect_program,
edge_output.exported_program.graph_module,
)
self.assertEqual(
etrecord._debug_handle_map,
json.loads(json.dumps(et_output.debug_handle_map)),
)

def test_etrecord_generation_with_manager(self):
captured_output, edge_output, et_output = self.get_test_model_with_manager()
with tempfile.TemporaryDirectory() as tmpdirname:
generate_etrecord(
tmpdirname + "/etrecord.bin",
edge_output,
et_output,
)

etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")
self.check_graph_closeness(
etrecord.graph_map["et_dialect_graph_module/forward"],
et_output.dump_exported_program(),
etrecord.edge_dialect_program,
edge_output.exported_program().graph_module,
)
self.assertEqual(
etrecord._debug_handle_map,
json.loads(json.dumps(et_output.debug_handle_map)),
)
self.assertEqual(etrecord.program_buffer, program_buffer)

def test_etrecord_invalid_input(self):
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
captured_output, edge_output, et_output = self.get_test_model()
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertRaises(RuntimeError):
generate_etrecord(
Expand All @@ -89,7 +113,7 @@ def test_etrecord_invalid_input(self):
)

def test_etrecord_reserved_name(self):
captured_output, edge_output, et_output, program_buffer = self.get_test_model()
captured_output, edge_output, et_output = self.get_test_model()
with tempfile.TemporaryDirectory() as tmpdirname:
for reserved_name in ETRecordReservedFileNames:
with self.assertRaises(RuntimeError):
Expand Down
2 changes: 1 addition & 1 deletion sdk/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_inspector_get_exported_program(self):
)

# Gen a mock etrecord
captured_output, edge_output, et_output, _ = TestETRecord().get_test_model()
captured_output, edge_output, et_output = TestETRecord().get_test_model()
with tempfile.TemporaryDirectory() as tmpdirname:
generate_etrecord(
tmpdirname + "/etrecord.bin",
Expand Down
6 changes: 1 addition & 5 deletions sdk/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

class TestInspectorUtils(unittest.TestCase):
def test_gen_graphs_from_etrecord(self):
captured_output, edge_output, et_output, _ = TestETRecord().get_test_model()
captured_output, edge_output, et_output = TestETRecord().get_test_model()
with tempfile.TemporaryDirectory() as tmpdirname:
generate_etrecord(
tmpdirname + "/etrecord.bin",
Expand All @@ -43,15 +43,11 @@ def test_gen_graphs_from_etrecord(self):
graphs = gen_graphs_from_etrecord(etrecord)

self.assertTrue("aten_dialect_output/forward" in graphs)
self.assertTrue("et_dialect_graph_module/forward" in graphs)
self.assertTrue(EDGE_DIALECT_GRAPH_KEY in graphs)

self.assertTrue(
isinstance(graphs["aten_dialect_output/forward"], FXOperatorGraph)
)
self.assertTrue(
isinstance(graphs["et_dialect_graph_module/forward"], FXOperatorGraph)
)
self.assertTrue(isinstance(graphs[EDGE_DIALECT_GRAPH_KEY], FXOperatorGraph))

def test_create_debug_handle_to_op_node_mapping(self):
Expand Down