Skip to content

Commit 2ae19c4

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
Populate Event attributes with op nodes metadata linked by debug handles (#401)
Summary: The main logic change is in the __init__() of Inspector. Differential Revision: D49326221
1 parent 49d2e68 commit 2ae19c4

File tree

7 files changed

+441
-48
lines changed

7 files changed

+441
-48
lines changed

sdk/etdb/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ python_library(
4141
"//executorch/exir:lib",
4242
"//executorch/sdk/edir:et_schema",
4343
"//executorch/sdk/etdump:schema_flatcc",
44-
"//executorch/sdk/etrecord:etrecord",
4544
],
4645
)
4746

@@ -52,6 +51,8 @@ python_library(
5251
],
5352
deps = [
5453
"//executorch/sdk/edir:et_schema",
54+
"//executorch/sdk/etdump:schema_flatcc",
55+
"//executorch/sdk/etdump:serialize",
5556
"//executorch/sdk/etrecord:etrecord",
5657
],
5758
)

sdk/etdb/_inspector_utils.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Mapping
7+
from typing import Dict, Mapping, Optional, Tuple
88

9-
from executorch.sdk.edir.et_schema import FXOperatorGraph, OperatorGraphWithStats
10-
from executorch.sdk.etrecord import ETRecord
9+
from executorch.sdk.edir.et_schema import (
10+
FXOperatorGraph,
11+
OperatorGraphWithStats,
12+
OperatorNode,
13+
)
14+
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC
15+
16+
from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
17+
from executorch.sdk.etrecord import ETRecord, parse_etrecord
1118

1219

13-
# TODO: add a unittest for this function
1420
def gen_graphs_from_etrecord(
1521
etrecord: ETRecord,
1622
) -> Mapping[str, OperatorGraphWithStats]:
@@ -20,3 +26,44 @@ def gen_graphs_from_etrecord(
2026
name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module)
2127
for name, exported_program in etrecord.graph_map.items()
2228
}
29+
30+
31+
def create_debug_handle_to_op_node_mapping(
32+
op_graph: OperatorGraphWithStats,
33+
debug_handle_to_op_node_map: Dict[int, OperatorNode],
34+
) -> None:
35+
"""
36+
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
37+
from each debug handle to the operator node that contains the debug handle in its metadata.
38+
"""
39+
# Recursively searches through the metadata of nodes
40+
for element in op_graph.elements:
41+
if isinstance(element, OperatorGraphWithStats):
42+
create_debug_handle_to_op_node_mapping(element, debug_handle_to_op_node_map)
43+
if isinstance(element, OperatorNode) and element.metadata is not None:
44+
metadata = element.metadata
45+
debug_handle = metadata.get("debug_handle")
46+
if debug_handle is not None:
47+
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
48+
if existing_entry is not None:
49+
raise ValueError(
50+
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
51+
"No two op nodes of the same graph should have the same debug handle."
52+
)
53+
debug_handle_to_op_node_map[debug_handle] = element
54+
55+
56+
def gen_etdump_and_etrecord_objects(
57+
etdump_path: Optional[str] = None, etrecord_path: Optional[str] = None
58+
) -> Tuple[ETDumpFlatCC, ETRecord]:
59+
# Gen event blocks from etdump
60+
if etdump_path is None:
61+
raise ValueError("Etdump_path must be specified.")
62+
with open(etdump_path, "rb") as buff:
63+
etdump = deserialize_from_etdump_flatcc(buff.read())
64+
# Gen op graphs from etrecord
65+
if etrecord_path is None:
66+
raise ValueError("Etrecord_path must be specified.")
67+
etrecord = parse_etrecord(etrecord_path=etrecord_path)
68+
69+
return (etdump, etrecord)

sdk/etdb/inspector.py

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@
2525
import torch
2626
from executorch.exir import ExportedProgram
2727

28-
from executorch.sdk.edir.et_schema import OperatorGraphWithStats
29-
from executorch.sdk.etdb._inspector_utils import gen_graphs_from_etrecord
28+
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
29+
from executorch.sdk.etdb._inspector_utils import (
30+
create_debug_handle_to_op_node_mapping,
31+
gen_etdump_and_etrecord_objects,
32+
gen_graphs_from_etrecord,
33+
)
3034
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
31-
from executorch.sdk.etrecord import parse_etrecord
35+
3236
from tabulate import tabulate
3337

3438
log: logging.Logger = logging.getLogger(__name__)
@@ -112,7 +116,7 @@ class Event:
112116

113117
name: str
114118
perf_data: PerfData
115-
op_type: List[str] = dataclasses.field(default_factory=list)
119+
op_types: List[str] = dataclasses.field(default_factory=list)
116120

117121
# Instruction Id of the original profiling event
118122
instruction_id: Optional[int] = None
@@ -123,7 +127,7 @@ class Event:
123127
# Debug Handles in the model graph to which this event is correlated
124128
debug_handles: Optional[Union[int, Sequence[int]]] = None
125129

126-
stack_trace: Dict[str, str] = dataclasses.field(default_factory=dict)
130+
stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict)
127131
module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict)
128132
is_delegated_op: Optional[bool] = None
129133
delegate_backend_name: Optional[str] = None
@@ -158,6 +162,33 @@ def _gen_from_profile_events(
158162
is_delegated_op=is_delegated_op,
159163
)
160164

165+
def _associate_with_op_graph_nodes(
166+
self, debug_handle_to_op_node_map: Dict[int, OperatorNode]
167+
) -> None:
168+
"""
169+
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
170+
based on the debug handles of this event
171+
"""
172+
debug_handles = []
173+
if self.debug_handles is None:
174+
return
175+
176+
if isinstance(self.debug_handles, int):
177+
debug_handles = [self.debug_handles]
178+
elif isinstance(self.debug_handles, Sequence):
179+
debug_handles = self.debug_handles
180+
181+
for handle in debug_handles:
182+
node = debug_handle_to_op_node_map.get(handle)
183+
if node is not None and node.metadata is not None:
184+
self.stack_traces[node.name] = node.metadata.get("stack_trace")
185+
if node.metadata:
186+
self.module_hierarchy[node.name] = node.metadata.get(
187+
"nn_module_stack"
188+
)
189+
if node.op:
190+
self.op_types += [node.op]
191+
161192

162193
@dataclass
163194
class EventBlock:
@@ -186,11 +217,11 @@ def to_dataframe(self) -> pd.DataFrame:
186217
"min": [event.perf_data.min for event in self.events],
187218
"max": [event.perf_data.max for event in self.events],
188219
"median": [event.perf_data.median for event in self.events],
189-
"op_type": [event.op_type for event in self.events],
220+
"op_types": [event.op_types for event in self.events],
190221
"delegate_debug_identifier": [
191222
event.delegate_debug_identifier for event in self.events
192223
],
193-
"stack_traces": [event.stack_trace for event in self.events],
224+
"stack_traces": [event.stack_traces for event in self.events],
194225
"module_hierarchy": [event.module_hierarchy for event in self.events],
195226
"is_delegated_op": [event.is_delegated_op for event in self.events],
196227
"delegate_backend_name": [
@@ -290,6 +321,9 @@ def _gen_resolve_debug_handles(
290321
)
291322

292323

324+
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
325+
326+
293327
class Inspector:
294328
"""
295329
APIs for examining model architecture and performance stats
@@ -301,16 +335,33 @@ def __init__(
301335
"""
302336
Create an inspector instance from the provided ETDump/ETRecord
303337
"""
338+
etdump, self._etrecord = gen_etdump_and_etrecord_objects(
339+
etdump_path=etdump_path, etrecord_path=etrecord_path
340+
)
341+
342+
self.event_blocks = EventBlock._gen_from_etdump(etdump)
304343

305-
# Gen op graphs from etrecord
306-
if etrecord_path is not None:
307-
self._etrecord = parse_etrecord(etrecord_path=etrecord_path)
308-
self._op_graph_dict: Mapping[
309-
str, OperatorGraphWithStats
310-
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
344+
self._op_graph_dict: Mapping[
345+
str, OperatorGraphWithStats
346+
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
311347

312-
self.event_blocks: List[EventBlock] = []
313-
# TODO: create event blocks from etdump, and associate events with op graph nodes
348+
# Use the delegate map from etrecord, associate debug handles with each event
349+
for event_block in self.event_blocks:
350+
event_block._gen_resolve_debug_handles(
351+
self._etrecord._debug_handle_map.get("forward"),
352+
self._etrecord._delegate_map.get("forward"),
353+
)
354+
355+
# Traverse the edge dialect op graph to create mapping from debug_handle to op node
356+
debug_handle_to_op_node_map = {}
357+
create_debug_handle_to_op_node_mapping(
358+
self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
359+
debug_handle_to_op_node_map,
360+
)
361+
362+
for event_block in self.event_blocks:
363+
for event in event_block.events:
364+
event._associate_with_op_graph_nodes(debug_handle_to_op_node_map)
314365

315366
def print_data_tabular(self) -> None:
316367
"""
@@ -353,12 +404,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
353404
# TODO: implement
354405
pass
355406

356-
# TODO: add a unittest for this function
357-
def get_exported_program(self, graph: Optional[str]) -> ExportedProgram:
407+
def get_exported_program(
408+
self, graph: Optional[str] = EDGE_DIALECT_GRAPH_KEY
409+
) -> ExportedProgram:
358410
"""
359411
Access helper for ETRecord, defaults to returning Edge Dialect Program
412+
413+
Args:
414+
graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
360415
"""
361-
if not graph:
362-
return self._etrecord["edge_dialect_output/forward"]
363-
else:
364-
return self._etrecord.get(graph)
416+
return self._etrecord.graph_map.get(graph)

sdk/etdb/tests/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ python_unittest(
66
name = "inspector_test",
77
srcs = ["inspector_test.py"],
88
deps = [
9+
"//executorch/exir:lib",
10+
"//executorch/sdk/edir:et_schema",
911
"//executorch/sdk/etdb:inspector",
12+
"//executorch/sdk/etrecord:etrecord",
13+
"//executorch/sdk/etrecord/tests:etrecord_test_library",
1014
],
1115
)
1216

@@ -18,3 +22,14 @@ python_unittest(
1822
"//executorch/sdk/etdump:schema_flatcc",
1923
],
2024
)
25+
26+
python_unittest(
27+
name = "inspector_utils_test",
28+
srcs = ["inspector_utils_test.py"],
29+
deps = [
30+
"//executorch/sdk/edir:et_schema",
31+
"//executorch/sdk/etdb:inspector_utils",
32+
"//executorch/sdk/etrecord:etrecord",
33+
"//executorch/sdk/etrecord/tests:etrecord_test_library",
34+
],
35+
)

0 commit comments

Comments
 (0)