Adding support for delegate map in emitter #334

Closed · wants to merge 1 commit
19 changes: 18 additions & 1 deletion exir/emit/_emit_program.py
@@ -11,7 +11,12 @@
import executorch.extension.pytree as ex_pytree
import torch
import torch.fx
from executorch.exir.emit._emitter import _EmitterState, _ProgramState, _TopLevelEmitter
from executorch.exir.emit._emitter import (
_DelegateDebugIdentifierMap,
_EmitterState,
_ProgramState,
_TopLevelEmitter,
)
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.schema import (
Bool,
@@ -108,6 +113,13 @@ class EmitterOutput:
    # debug handles or list of debug handles in the case of delegate calls.
    debug_handle_map: Dict[int, Union[int, List[int]]]

    # Maps each method name to a dict that maps each delegate instruction id
    # to the corresponding delegate's name and its delegate debug identifier
    # mapping.
    method_to_delegate_debug_id_map: Dict[
        str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
    ]


def emit_program(
    methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
@@ -145,6 +157,7 @@ def emit_program(

    plans = []
    debug_handle_map = {}
    method_to_delegate_debug_id_map = {}
    program_state = _ProgramState()

    # emit each entry point in order according to name.
@@ -177,13 +190,17 @@
            program_state.allocated_specs
        )
        debug_handle_map[name] = emitter.debug_handle_map
        method_to_delegate_debug_id_map[
            name
        ] = emitter.instr_id_to_delegate_debug_id_map

    # emit any primitive getters
    if prim_getters is not None:
        plans.extend(_emit_prim_getters(prim_getters))

    return EmitterOutput(
        debug_handle_map=debug_handle_map,
        method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
        program=Program(
            version=EXECUTORCH_SCHEMA_VERSION,
            execution_plan=plans,
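Note for readers: the nesting of the new EmitterOutput field is easiest to see with a concrete value. Below is a minimal sketch of the structure; the method name, instruction id, backend name, and debug identifiers are illustrative assumptions, not values from a real program.

# method name -> delegate instruction id -> {"name": ..., "delegate_map": ...}
example_map = {
    "forward": {
        0: {
            "name": "BackendWithCompilerDemo",  # the lowered module's backend_id
            "delegate_map": {1: (1, 2)},  # delegate debug id -> debug handle(s)
        }
    }
}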
31 changes: 29 additions & 2 deletions exir/emit/_emitter.py
@@ -186,6 +186,9 @@ class _AbstractValue:
    None,
]

_DelegateDebugIdentifierMap: TypeAlias = Union[
Dict[int, Tuple[int]], Dict[str, Tuple[int]]
]

# pyre-ignore[13]: Attribute `node` is never initialized.
class _Emitter(torch.fx.Interpreter):
@@ -231,6 +234,9 @@ def __init__(

        self.concrete_output_ids: List[_AbstractValue] = []
        self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
        self.instr_id_to_delegate_debug_id_map: Dict[
            int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
        ] = {}

    def _stacktrace_to_framelist(self, stacktrace: str) -> FrameList:
        """Creates a frame list from a stacktrace string."""
@@ -931,6 +937,26 @@ def _add_debug_handle(self, emitter_id: int, target: _Target) -> None:
        # the node.
        self.node.meta["debug_handle"] = emitter_id

    def _add_delegate_map(
        self,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "LoweredBackendModule",  # noqa
        delegate_instruction_id: int,
    ) -> None:
        """
        Store the delegate map from this lowered module in the dictionary of
        delegate maps. It is later used for debugging purposes such as linking
        back to the original source code and module hierarchy.
        """
        delegate_map = {}
        if hasattr(lowered_module, "meta"):
            delegate_map = lowered_module.meta.get("delegate_map", {})

        self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
            "name": lowered_module.backend_id,
            "delegate_map": delegate_map,
        }

    def _emit_argument(
        self, arg: _Argument, arg_type: Optional[_SchemaType]
    ) -> _AbstractValue:
@@ -942,7 +968,6 @@ def _emit_argument(

    def _emit_delegate(
        self,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: "LoweredBackendModule",  # noqa
        args: Tuple[_Argument, ...],
        kwargs: Dict[str, _Argument],
@@ -1247,7 +1272,9 @@ def call_function(
            lowered_module = args[0]
            assert is_lowered_module(lowered_module)
            v = self._emit_delegate(lowered_module, args[1:], kwargs)
            self._add_debug_handle(len(self.chain.instructions) - 1, target)
            delegate_instruction_id = len(self.chain.instructions) - 1
            self._add_debug_handle(delegate_instruction_id, target)
            self._add_delegate_map(lowered_module, delegate_instruction_id)
            return v

        elif isinstance(
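For context on where lowered_module.meta["delegate_map"] comes from: in this flow a backend returns the mapping from preprocess inside its PreprocessResult, and to_backend is expected to attach it to the lowered module's meta. A minimal backend sketch under that assumption (the class name and values are hypothetical, the import path is assumed, and it mirrors the test below):

from executorch.exir.backend.backend_details import (
    BackendDetails,
    PreprocessResult,
)

class SketchBackend(BackendDetails):  # hypothetical backend, for illustration only
    @staticmethod
    def preprocess(edge_program, compile_specs) -> PreprocessResult:
        return PreprocessResult(
            processed_bytes=b"test",
            # Delegate debug id -> original debug handles; expected to surface
            # as lowered_module.meta["delegate_map"] after to_backend runs.
            debug_handle_map={1: (1, 2)},
        )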
46 changes: 46 additions & 0 deletions exir/emit/test/test_emit.py
@@ -1240,3 +1240,49 @@ def forward(self, list_a):
            .to_executorch()
        )
        exec_prog.buffer

    def test_delegate_mapping(self) -> None:
        debug_handle_map = {1: [1, 2]}

        class BackendWithCompilerDemo(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=debug_handle_map,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()

            def forward(self, x, y):
                return torch.add(x, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        model = TestModel()
        edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge(
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )
        lowered_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, None
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x, y):
                return self.lowered_module(x, y)

        composite_model = CompositeModule()
        exec_prog = (
            exir.capture(composite_model, inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch()
        )
        self.assertIsNotNone(exec_prog.delegate_map)
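If the map's contents were pinned down, a stronger assertion than assertIsNotNone would be possible. Continuing inside test_delegate_mapping, a hedged sketch that assumes the delegate call is the first instruction emitted for the "forward" method:

        # Hypothetical follow-up checks; instruction id 0 is an assumption.
        entry = exec_prog.delegate_map["forward"][0]
        self.assertEqual(entry["name"], "BackendWithCompilerDemo")
        self.assertEqual(entry["delegate_map"], debug_handle_map)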
1 change: 1 addition & 0 deletions exir/program/TARGETS
@@ -22,6 +22,7 @@ python_library(
        "//executorch/exir:schema",
        "//executorch/exir/_serialize:lib",
        "//executorch/exir/capture:config",
        "//executorch/exir/emit:emit",
        "//executorch/exir/emit:lib",
        "//executorch/exir/passes:lib",
        "//executorch/exir/passes:remove_assert_async_pass",
17 changes: 17 additions & 0 deletions exir/program/_program.py
@@ -13,6 +13,7 @@
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.emit import emit_program, EmitterOutput
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
from executorch.exir.error import ExportError
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import (
@@ -251,6 +252,14 @@ def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
            return self._emitter_output.debug_handle_map
        return {}

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        if self._emitter_output:
            return self._emitter_output.method_to_delegate_debug_id_map
        return {}

    @property
    def graph_module(self) -> torch.fx.GraphModule:
        return self.exported_program.graph_module
@@ -498,6 +507,14 @@ def program(self) -> Program:
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        return self._emitter_output.debug_handle_map

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        if self._emitter_output:
            return self._emitter_output.method_to_delegate_debug_id_map
        return {}

    # TODO(ycao): This doesn't make sense any more, remove/change later.
    def dump_graph_module(self) -> torch.fx.GraphModule:
        return self.get_multi_method_graph_module().module
return self.get_multi_method_graph_module().module
Expand Down