Skip to content

Commit 516c291

Browse files
tarun292 and facebook-github-bot
authored and committed
Adding support for delegate map in emitter (#334)
Summary: LoweredBackendModules might optionally have a delegate debug handle map embedded in them if `preprocess` returned a result with a map inside it. The emitter internally generates a mapping which is: ``` {'instruction_id' : {'name':delegate_name, 'debug_handle_map': delegate_identifier_map} ``` and returns this in the `EmitterOutput` object which is then subsequently copied over into the `ExecutorchProgram` object returned by `to_executorch()`. This enables us to serialize this map into etrecord which we can then use later for mapping back delegate internal operations to the original node in the graph. Differential Revision: D49173918
1 parent 4f3e5e6 commit 516c291

File tree

5 files changed

+108
-3
lines changed

5 files changed

+108
-3
lines changed

exir/emit/_emit_program.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
import executorch.extension.pytree as ex_pytree
1212
import torch
1313
import torch.fx
14-
from executorch.exir.emit._emitter import _EmitterState, _ProgramState, _TopLevelEmitter
14+
from executorch.exir.emit._emitter import (
15+
_DelegateDebugIdentifierMap,
16+
_EmitterState,
17+
_ProgramState,
18+
_TopLevelEmitter,
19+
)
1520
from executorch.exir.error import ExportError, ExportErrorType
1621
from executorch.exir.schema import (
1722
Bool,
@@ -108,6 +113,13 @@ class EmitterOutput:
108113
# debug handles or list of debug handles in the case of delegate calls.
109114
debug_handle_map: Dict[int, Union[int, List[int]]]
110115

116+
# This dictionary maps the method name to the corresponding dict which
117+
# contains the mapping of the delegate instruction id to its corresponding
118+
# delegate name and delegate debug identifier mapping.
119+
method_to_delegate_debug_id_map: Dict[
120+
str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
121+
]
122+
111123

112124
def emit_program(
113125
methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
@@ -145,6 +157,7 @@ def emit_program(
145157

146158
plans = []
147159
debug_handle_map = {}
160+
method_to_delegate_debug_id_map = {}
148161
program_state = _ProgramState()
149162

150163
# emit each entry point in order according to name.
@@ -177,13 +190,17 @@ def emit_program(
177190
program_state.allocated_specs
178191
)
179192
debug_handle_map[name] = emitter.debug_handle_map
193+
method_to_delegate_debug_id_map[
194+
name
195+
] = emitter.instr_id_to_delegate_debug_id_map
180196

181197
# emit any primitive getters
182198
if prim_getters is not None:
183199
plans.extend(_emit_prim_getters(prim_getters))
184200

185201
return EmitterOutput(
186202
debug_handle_map=debug_handle_map,
203+
method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
187204
program=Program(
188205
version=EXECUTORCH_SCHEMA_VERSION,
189206
execution_plan=plans,

exir/emit/_emitter.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ class _AbstractValue:
186186
None,
187187
]
188188

189+
_DelegateDebugIdentifierMap: TypeAlias = Dict[
190+
int, Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]
191+
]
189192

190193
# pyre-ignore[13]: Attribute `node` is never initialized.
191194
class _Emitter(torch.fx.Interpreter):
@@ -231,6 +234,9 @@ def __init__(
231234

232235
self.concrete_output_ids: List[_AbstractValue] = []
233236
self.debug_handle_map: Dict[int, Union[int, List[int]]] = {}
237+
self.instr_id_to_delegate_debug_id_map: Dict[
238+
int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]
239+
] = {}
234240

235241
def _stacktrace_to_framelist(self, stacktrace: str) -> FrameList:
236242
"""Creates a frame list from a stacktrace string."""
@@ -931,6 +937,22 @@ def _add_debug_handle(self, emitter_id: int, target: _Target) -> None:
931937
# the node.
932938
self.node.meta["debug_handle"] = emitter_id
933939

940+
def _add_delegate_map(
941+
self,
942+
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
943+
lowered_module: "LoweredBackendModule", # noqa
944+
delegate_instruction_id: int,
945+
) -> None:
946+
"""
947+
Store the delegate map from this lowered module into the dictionary of delegate maps. It
948+
will later be used for various debugging purposes such as linking back to original source
949+
code, module hierarchy etc.
950+
"""
951+
self.instr_id_to_delegate_debug_id_map[delegate_instruction_id] = {
952+
"name": lowered_module.backend_id,
953+
"delegate_map": lowered_module.meta.get("debug_handle_map"),
954+
}
955+
934956
def _emit_argument(
935957
self, arg: _Argument, arg_type: Optional[_SchemaType]
936958
) -> _AbstractValue:
@@ -942,7 +964,6 @@ def _emit_argument(
942964

943965
def _emit_delegate(
944966
self,
945-
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
946967
lowered_module: "LoweredBackendModule", # noqa
947968
args: Tuple[_Argument, ...],
948969
kwargs: Dict[str, _Argument],
@@ -1247,7 +1268,9 @@ def call_function(
12471268
lowered_module = args[0]
12481269
assert is_lowered_module(lowered_module)
12491270
v = self._emit_delegate(lowered_module, args[1:], kwargs)
1250-
self._add_debug_handle(len(self.chain.instructions) - 1, target)
1271+
delegate_instruction_id = len(self.chain.instructions) - 1
1272+
self._add_debug_handle(delegate_instruction_id, target)
1273+
self._add_delegate_map(lowered_module, delegate_instruction_id)
12511274
return v
12521275

12531276
elif isinstance(

exir/emit/test/test_emit.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,7 @@ def forward(self, x):
12241224
lowered_module = to_backend(
12251225
"BackendWithCompilerDemo", edgeir_m.exported_program, None
12261226
)
1227+
lowered_module.program()
12271228

12281229
class CompositeModule(torch.nn.Module):
12291230
def __init__(self):
@@ -1240,3 +1241,49 @@ def forward(self, list_a):
12401241
.to_executorch()
12411242
)
12421243
exec_prog.buffer
1244+
1245+
def test_delegate_mapping(self) -> None:
1246+
debug_handle_map = {1: [1, 2]}
1247+
1248+
class BackendWithCompilerDemo(BackendDetails):
1249+
@staticmethod
1250+
def preprocess(
1251+
edge_program,
1252+
compile_specs,
1253+
) -> bytes:
1254+
return PreprocessResult(
1255+
processed_bytes=bytes(str("test"), encoding="utf8"),
1256+
debug_handle_map=debug_handle_map,
1257+
)
1258+
1259+
class TestModel(nn.Module):
1260+
def __init__(self):
1261+
super(TestModel, self).__init__()
1262+
1263+
def forward(self, x, y):
1264+
return torch.add(x, y)
1265+
1266+
inputs = (torch.ones(2, 2), torch.ones(2, 2))
1267+
model = TestModel()
1268+
edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge(
1269+
exir.EdgeCompileConfig(_check_ir_validity=False)
1270+
)
1271+
lowered_module = to_backend(
1272+
"BackendWithCompilerDemo", edgeir_m.exported_program, None
1273+
)
1274+
1275+
class CompositeModule(torch.nn.Module):
1276+
def __init__(self):
1277+
super().__init__()
1278+
self.lowered_module = lowered_module
1279+
1280+
def forward(self, x, y):
1281+
return self.lowered_module(x, y)
1282+
1283+
composite_model = CompositeModule()
1284+
exec_prog = (
1285+
exir.capture(composite_model, inputs, exir.CaptureConfig())
1286+
.to_edge()
1287+
.to_executorch()
1288+
)
1289+
self.assertIsNotNone(exec_prog.delegate_map)

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python_library(
2222
"//executorch/exir:schema",
2323
"//executorch/exir/_serialize:lib",
2424
"//executorch/exir/capture:config",
25+
"//executorch/exir/emit:emit",
2526
"//executorch/exir/emit:lib",
2627
"//executorch/exir/passes:lib",
2728
"//executorch/exir/passes:remove_assert_async_pass",

exir/program/_program.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.exir._serialize import _serialize_pte_binary
1414
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
1515
from executorch.exir.emit import emit_program, EmitterOutput
16+
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
1617
from executorch.exir.error import ExportError
1718
from executorch.exir.pass_manager import PassType
1819
from executorch.exir.passes import (
@@ -251,6 +252,14 @@ def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
251252
return self._emitter_output.debug_handle_map
252253
return {}
253254

255+
@property
256+
def delegate_map(
257+
self,
258+
) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
259+
if self._emitter_output:
260+
return self._emitter_output.method_to_delegate_debug_id_map
261+
return {}
262+
254263
@property
255264
def graph_module(self) -> torch.fx.GraphModule:
256265
return self.exported_program.graph_module
@@ -498,6 +507,14 @@ def program(self) -> Program:
498507
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
499508
return self._emitter_output.debug_handle_map
500509

510+
@property
511+
def delegate_map(
512+
self,
513+
) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
514+
if self._emitter_output:
515+
return self._emitter_output.method_to_delegate_debug_id_map
516+
return {}
517+
501518
# TODO(ycao): This doesn't make sense any more, remove/change later.
502519
def dump_graph_module(self) -> torch.fx.GraphModule:
503520
return self.get_multi_method_graph_module().module

0 commit comments

Comments (0)