Skip to content

Commit fdcabe8

Browse files
tarun292 authored and facebook-github-bot committed
Adding support for delegate map in emitter (#334)
Summary: LoweredBackendModules might optionally have a delegate debug handle map embedded in them if `preprocess` returned a result with a map inside it. The emitter internally generates a mapping which is: ``` {'instruction_id' : {'name':delegate_name, 'debug_handle_map': delegate_identifier_map} ``` and returns this in the `EmitterOutput` object which is then subsequently copied over into the `ExecutorchProgram` object returned by `to_executorch()`. This enables us to serialize this map into etrecord which we can then use later for mapping back delegate internal operations to the original node in the graph. Reviewed By: Jack-Khuu Differential Revision: D49173918
1 parent 65189b5 commit fdcabe8

File tree

5 files changed

+111
-3
lines changed

5 files changed

+111
-3
lines changed

exir/emit/_emit_program.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
import executorch.extension.pytree as ex_pytree
1212
import torch
1313
import torch.fx
14-
from executorch.exir.emit._emitter import _EmitterState, _ProgramState, _TopLevelEmitter
14+
from executorch.exir.emit._emitter import (
15+
_DelegateDebugIdentifierMap,
16+
_EmitterState,
17+
_ProgramState,
18+
_TopLevelEmitter,
19+
)
1520
from executorch.exir.error import ExportError, ExportErrorType
1621
from executorch.exir.schema import (
1722
Bool,
@@ -108,6 +113,13 @@ class EmitterOutput:
108113
# debug handles or list of debug handles in the case of delegate calls.
109114
debug_handle_map: Dict[int, Union[int, List[int]]]
110115

116+
# This dictionary maps the method name to the corresponding dict which
117+
# contains the mapping of the delegate instruction id to its corresponding
118+
# delegate name and delegate debug identifier mapping.
119+
method_to_delegate_debug_id_map: Dict[
120+
str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
121+
]
122+
111123

112124
def emit_program(
113125
methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
@@ -145,6 +157,7 @@ def emit_program(
145157

146158
plans = []
147159
debug_handle_map = {}
160+
method_to_delegate_debug_id_map = {}
148161
program_state = _ProgramState()
149162

150163
# emit each entry point in order according to name.
@@ -177,13 +190,17 @@ def emit_program(
177190
program_state.allocated_specs
178191
)
179192
debug_handle_map[name] = emitter.debug_handle_map
193+
method_to_delegate_debug_id_map[
194+
name
195+
] = emitter.instr_id_to_delegate_debug_id_map
180196

181197
# emit any primitive getters
182198
if prim_getters is not None:
183199
plans.extend(_emit_prim_getters(prim_getters))
184200

185201
return EmitterOutput(
186202
debug_handle_map=debug_handle_map,
203+
method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
187204
program=Program(
188205
version=EXECUTORCH_SCHEMA_VERSION,
189206
execution_plan=plans,

exir/emit/_emitter.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ class _AbstractValue:
186186
None,
187187
]
188188

189+
_DelegateDebugIdentifierMap: TypeAlias = Dict[
190+
int, Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]
191+
]
189192

190193
# pyre-ignore[13]: Attribute `node` is never initialized.
191194
class _Emitter(torch.fx.Interpreter):
@@ -231,6 +234,9 @@ def __init__(
231234

232235
self.concrete_output_ids: List[_AbstractValue] = []
233236
self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
237+
self.instr_id_to_delegate_debug_id_map: Dict[
238+
int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
239+
] = {}
234240

235241
def _stacktrace_to_framelist(self, stacktrace: str) -> FrameList:
236242
"""Creates a frame list from a stacktrace string."""
@@ -931,6 +937,26 @@ def _add_debug_handle(self, emitter_id: int, target: _Target) -> None:
931937
# the node.
932938
self.node.meta["debug_handle"] = emitter_id
933939

940+
def _add_delegate_map(
941+
self,
942+
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
943+
lowered_module: "LoweredBackendModule", # noqa
944+
delegate_instruction_id: int,
945+
) -> None:
946+
"""
947+
Store the delegate map from this lowered module into the dictionary of delegate maps. It
948+
will later be used for various debugging purposes such as linking back to original source
949+
code, module hierarchy etc.
950+
"""
951+
delegate_map = {}
952+
if hasattr(lowered_module, "meta"):
953+
delegate_map = lowered_module.meta.get("delegate_map", {})
954+
955+
self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
956+
"name": lowered_module.backend_id,
957+
"delegate_map": delegate_map,
958+
}
959+
934960
def _emit_argument(
935961
self, arg: _Argument, arg_type: Optional[_SchemaType]
936962
) -> _AbstractValue:
@@ -942,7 +968,6 @@ def _emit_argument(
942968

943969
def _emit_delegate(
944970
self,
945-
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
946971
lowered_module: "LoweredBackendModule", # noqa
947972
args: Tuple[_Argument, ...],
948973
kwargs: Dict[str, _Argument],
@@ -1247,7 +1272,9 @@ def call_function(
12471272
lowered_module = args[0]
12481273
assert is_lowered_module(lowered_module)
12491274
v = self._emit_delegate(lowered_module, args[1:], kwargs)
1250-
self._add_debug_handle(len(self.chain.instructions) - 1, target)
1275+
delegate_instruction_id = len(self.chain.instructions) - 1
1276+
self._add_debug_handle(delegate_instruction_id, target)
1277+
self._add_delegate_map(lowered_module, delegate_instruction_id)
12511278
return v
12521279

12531280
elif isinstance(

exir/emit/test/test_emit.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,3 +1240,49 @@ def forward(self, list_a):
12401240
.to_executorch()
12411241
)
12421242
exec_prog.buffer
1243+
1244+
def test_delegate_mapping(self) -> None:
1245+
debug_handle_map = {1: [1, 2]}
1246+
1247+
class BackendWithCompilerDemo(BackendDetails):
1248+
@staticmethod
1249+
def preprocess(
1250+
edge_program,
1251+
compile_specs,
1252+
) -> bytes:
1253+
return PreprocessResult(
1254+
processed_bytes=bytes(str("test"), encoding="utf8"),
1255+
debug_handle_map=debug_handle_map,
1256+
)
1257+
1258+
class TestModel(nn.Module):
1259+
def __init__(self):
1260+
super(TestModel, self).__init__()
1261+
1262+
def forward(self, x, y):
1263+
return torch.add(x, y)
1264+
1265+
inputs = (torch.ones(2, 2), torch.ones(2, 2))
1266+
model = TestModel()
1267+
edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge(
1268+
exir.EdgeCompileConfig(_check_ir_validity=False)
1269+
)
1270+
lowered_module = to_backend(
1271+
"BackendWithCompilerDemo", edgeir_m.exported_program, None
1272+
)
1273+
1274+
class CompositeModule(torch.nn.Module):
1275+
def __init__(self):
1276+
super().__init__()
1277+
self.lowered_module = lowered_module
1278+
1279+
def forward(self, x, y):
1280+
return self.lowered_module(x, y)
1281+
1282+
composite_model = CompositeModule()
1283+
exec_prog = (
1284+
exir.capture(composite_model, inputs, exir.CaptureConfig())
1285+
.to_edge()
1286+
.to_executorch()
1287+
)
1288+
self.assertIsNotNone(exec_prog.delegate_map)

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python_library(
2222
"//executorch/exir:schema",
2323
"//executorch/exir/_serialize:lib",
2424
"//executorch/exir/capture:config",
25+
"//executorch/exir/emit:emit",
2526
"//executorch/exir/emit:lib",
2627
"//executorch/exir/passes:lib",
2728
"//executorch/exir/passes:remove_assert_async_pass",

exir/program/_program.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.exir._serialize import _serialize_pte_binary
1414
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
1515
from executorch.exir.emit import emit_program, EmitterOutput
16+
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
1617
from executorch.exir.error import ExportError
1718
from executorch.exir.pass_manager import PassType
1819
from executorch.exir.passes import (
@@ -251,6 +252,14 @@ def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
251252
return self._emitter_output.debug_handle_map
252253
return {}
253254

255+
@property
256+
def delegate_map(
257+
self,
258+
) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
259+
if self._emitter_output:
260+
return self._emitter_output.method_to_delegate_debug_id_map
261+
return {}
262+
254263
@property
255264
def graph_module(self) -> torch.fx.GraphModule:
256265
return self.exported_program.graph_module
@@ -498,6 +507,14 @@ def program(self) -> Program:
498507
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
499508
return self._emitter_output.debug_handle_map
500509

510+
@property
511+
def delegate_map(
512+
self,
513+
) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
514+
if self._emitter_output:
515+
return self._emitter_output.method_to_delegate_debug_id_map
516+
return {}
517+
501518
# TODO(ycao): This doesn't make sense any more, remove/change later.
502519
def dump_graph_module(self) -> torch.fx.GraphModule:
503520
return self.get_multi_method_graph_module().module

0 commit comments

Comments
 (0)