Skip to content

Commit c872580

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
Populate Event attributes with op nodes metadata linked by debug handles (#401)
Summary: Pull Request resolved: #401 The main logic change is in the __init__() of Inspector. Reviewed By: Jack-Khuu Differential Revision: D49326221 fbshipit-source-id: 2f16b87f5eb802adb0b2675eb247c06b340bf6c0
1 parent 1cc64fe commit c872580

File tree

8 files changed

+523
-66
lines changed

8 files changed

+523
-66
lines changed

sdk/etdb/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ python_library(
4141
"//executorch/exir:lib",
4242
"//executorch/sdk/edir:et_schema",
4343
"//executorch/sdk/etdump:schema_flatcc",
44-
"//executorch/sdk/etrecord:etrecord",
4544
],
4645
)
4746

@@ -52,6 +51,8 @@ python_library(
5251
],
5352
deps = [
5453
"//executorch/sdk/edir:et_schema",
54+
"//executorch/sdk/etdump:schema_flatcc",
55+
"//executorch/sdk/etdump:serialize",
5556
"//executorch/sdk/etrecord:etrecord",
5657
],
5758
)

sdk/etdb/_inspector_utils.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Mapping
7+
from typing import Dict, Mapping, Optional
88

9-
from executorch.sdk.edir.et_schema import FXOperatorGraph, OperatorGraphWithStats
10-
from executorch.sdk.etrecord import ETRecord
9+
from executorch.sdk.edir.et_schema import (
10+
FXOperatorGraph,
11+
OperatorGraphWithStats,
12+
OperatorNode,
13+
)
14+
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC
15+
16+
from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
17+
from executorch.sdk.etrecord import ETRecord, parse_etrecord
1118

1219

13-
# TODO: add a unittest for this function
1420
def gen_graphs_from_etrecord(
1521
etrecord: ETRecord,
1622
) -> Mapping[str, OperatorGraphWithStats]:
@@ -20,3 +26,46 @@ def gen_graphs_from_etrecord(
2026
name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module)
2127
for name, exported_program in etrecord.graph_map.items()
2228
}
29+
30+
31+
# TODO: use anonymous function to avoid passing the dict around
32+
# and move this inside of the OperatorGraphWithStats class
33+
def create_debug_handle_to_op_node_mapping(
34+
op_graph: OperatorGraphWithStats,
35+
debug_handle_to_op_node_map: Dict[int, OperatorNode],
36+
) -> None:
37+
"""
38+
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
39+
from each debug handle to the operator node that contains the debug handle in its metadata.
40+
"""
41+
# Recursively searches through the metadata of nodes
42+
for element in op_graph.elements:
43+
if isinstance(element, OperatorGraphWithStats):
44+
create_debug_handle_to_op_node_mapping(element, debug_handle_to_op_node_map)
45+
if isinstance(element, OperatorNode) and element.metadata is not None:
46+
metadata = element.metadata
47+
debug_handle = metadata.get("debug_handle")
48+
if debug_handle is not None:
49+
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
50+
if existing_entry is not None:
51+
raise ValueError(
52+
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
53+
"No two op nodes of the same graph should have the same debug handle."
54+
)
55+
debug_handle_to_op_node_map[debug_handle] = element
56+
57+
58+
def gen_etrecord_object(etrecord_path: Optional[str] = None) -> ETRecord:
59+
# Gen op graphs from etrecord
60+
if etrecord_path is None:
61+
raise ValueError("Etrecord_path must be specified.")
62+
return parse_etrecord(etrecord_path=etrecord_path)
63+
64+
65+
def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC:
66+
# Gen event blocks from etdump
67+
if etdump_path is None:
68+
raise ValueError("Etdump_path must be specified.")
69+
with open(etdump_path, "rb") as buff:
70+
etdump = deserialize_from_etdump_flatcc(buff.read())
71+
return etdump

sdk/etdb/inspector.py

Lines changed: 137 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,33 @@
2525
import torch
2626
from executorch.exir import ExportedProgram
2727

28-
from executorch.sdk.edir.et_schema import OperatorGraphWithStats
29-
from executorch.sdk.etdb._inspector_utils import gen_graphs_from_etrecord
28+
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
29+
from executorch.sdk.etdb._inspector_utils import (
30+
create_debug_handle_to_op_node_mapping,
31+
gen_etdump_object,
32+
gen_etrecord_object,
33+
gen_graphs_from_etrecord,
34+
)
3035
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
31-
from executorch.sdk.etrecord import parse_etrecord
36+
3237
from tabulate import tabulate
3338

39+
40+
FORWARD = "forward"
41+
RESERVED_SPECIAL_EVENT_NAMES = [
42+
"Method::init",
43+
"Program::load_method",
44+
"Method::execute",
45+
]
46+
EXCLUDED_COLUMNS_WHEN_PRINTING = [
47+
"raw",
48+
"delegate_debug_identifier",
49+
"stack_traces",
50+
"module_hierarchy",
51+
"debug_data",
52+
]
53+
54+
3455
log: logging.Logger = logging.getLogger(__name__)
3556

3657
# Signature of a ProfileEvent
@@ -112,7 +133,7 @@ class Event:
112133

113134
name: str
114135
perf_data: PerfData
115-
op_type: List[str] = dataclasses.field(default_factory=list)
136+
op_types: List[str] = dataclasses.field(default_factory=list)
116137

117138
# Instruction Id of the original profiling event
118139
instruction_id: Optional[int] = None
@@ -123,7 +144,7 @@ class Event:
123144
# Debug Handles in the model graph to which this event is correlated
124145
debug_handles: Optional[Union[int, Sequence[int]]] = None
125146

126-
stack_trace: Dict[str, str] = dataclasses.field(default_factory=dict)
147+
stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict)
127148
module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict)
128149
is_delegated_op: Optional[bool] = None
129150
delegate_backend_name: Optional[str] = None
@@ -138,9 +159,10 @@ def _gen_from_profile_events(
138159
return an Event object matching the ProfileEventSignature, with perf_data
139160
populated from the list of ProfileEvents
140161
"""
141-
delegate_debug_identifier = (
142-
signature.delegate_id or signature.delegate_id_str or None
143-
)
162+
if signature.delegate_id is not None: # 0 is a valid value
163+
delegate_debug_identifier = signature.delegate_id
164+
else:
165+
delegate_debug_identifier = signature.delegate_id_str or None
144166

145167
# Use the delegate identifier as the event name if delegated
146168
is_delegated_op = delegate_debug_identifier is not None
@@ -158,6 +180,28 @@ def _gen_from_profile_events(
158180
is_delegated_op=is_delegated_op,
159181
)
160182

183+
def _associate_with_op_graph_nodes(
184+
self, debug_handle_to_op_node_map: Dict[int, OperatorNode]
185+
) -> None:
186+
"""
187+
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
188+
based on the debug handles of this event
189+
"""
190+
if (debug_handles := self.debug_handles) is None:
191+
return
192+
193+
if isinstance(debug_handles, int):
194+
debug_handles = [debug_handles]
195+
196+
for handle in debug_handles:
197+
node = debug_handle_to_op_node_map.get(handle)
198+
if node is not None and (metadata := node.metadata) is not None:
199+
self.stack_traces[node.name] = metadata.get("stack_trace")
200+
self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
201+
if node.op:
202+
# TODO: consider having this as a dict from node.name -> node.op
203+
self.op_types += [node.op]
204+
161205

162206
@dataclass
163207
class EventBlock:
@@ -186,11 +230,11 @@ def to_dataframe(self) -> pd.DataFrame:
186230
"min": [event.perf_data.min for event in self.events],
187231
"max": [event.perf_data.max for event in self.events],
188232
"median": [event.perf_data.median for event in self.events],
189-
"op_type": [event.op_type for event in self.events],
233+
"op_types": [event.op_types for event in self.events],
190234
"delegate_debug_identifier": [
191235
event.delegate_debug_identifier for event in self.events
192236
],
193-
"stack_traces": [event.stack_trace for event in self.events],
237+
"stack_traces": [event.stack_traces for event in self.events],
194238
"module_hierarchy": [event.module_hierarchy for event in self.events],
195239
"is_delegated_op": [event.is_delegated_op for event in self.events],
196240
"delegate_backend_name": [
@@ -248,10 +292,11 @@ def _gen_from_etdump(etdump: ETDumpFlatCC) -> List["EventBlock"]:
248292
for index, profile_events in enumerate(profile_run_groups.values())
249293
]
250294

295+
# TODO: Considering changing ETRecord deserialization logic to cast the ints in string format to actual ints
251296
def _gen_resolve_debug_handles(
252297
self,
253-
handle_map: Dict[int, List[int]],
254-
delegate_map: Optional[Dict[int, DelegateMetadata]] = None,
298+
handle_map: Dict[str, List[int]],
299+
delegate_map: Optional[Dict[str, DelegateMetadata]] = None,
255300
):
256301
"""
257302
Given mappings from instruction id to debug handles, populate the
@@ -261,10 +306,12 @@ def _gen_resolve_debug_handles(
261306
to obtain the debug_handle via the delegate map
262307
"""
263308
for event in self.events:
309+
# Check if instruction_id is present in the event
310+
if event.instruction_id is None:
311+
continue
312+
264313
# Check for the instruction_id in handle map
265-
if (
266-
instruction_id := event.instruction_id
267-
) is None or instruction_id not in handle_map:
314+
if (instruction_id := str(event.instruction_id)) not in handle_map:
268315
continue
269316

270317
# For non-delegated event, handles are found in handle_map
@@ -285,14 +332,33 @@ def _gen_resolve_debug_handles(
285332

286333
# For delegated events, handles are found via delegateMetadata
287334
event.delegate_backend_name = delegate_metadata.get("name", "")
288-
event.debug_handles = delegate_metadata.get("delegate_map", {}).get(
335+
delegate_metadata_delegate_map = delegate_metadata.get("delegate_map", {})
336+
337+
# delegate_debug_id can be either int based or string based, therefore we need to check both
338+
debug_handles = delegate_metadata_delegate_map.get(
289339
delegate_debug_id # pyre-ignore
290340
)
341+
if debug_handles is not None:
342+
event.debug_handles = debug_handles
343+
else:
344+
event.debug_handles = delegate_metadata_delegate_map.get(
345+
str(delegate_debug_id) # pyre-ignore
346+
)
347+
348+
349+
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
291350

292351

293352
class Inspector:
294353
"""
295-
APIs for examining model architecture and performance stats
354+
APIs for examining model architecture and performance stats.
355+
356+
Public Attributes:
357+
event_blocks: List["EventBlocks"]. Structured data accessible through Inspector for analysis.
358+
359+
Private Attributes:
360+
_etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object.
361+
_op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
296362
"""
297363

298364
def __init__(
@@ -302,15 +368,34 @@ def __init__(
302368
Create an inspector instance from the provided ETDump/ETRecord
303369
"""
304370

305-
# Gen op graphs from etrecord
306-
if etrecord_path is not None:
307-
self._etrecord = parse_etrecord(etrecord_path=etrecord_path)
308-
self._op_graph_dict: Mapping[
309-
str, OperatorGraphWithStats
310-
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
371+
# TODO: etrecord_path can be optional, so need to support the case when it is not present
372+
self._etrecord = gen_etrecord_object(etrecord_path=etrecord_path)
373+
etdump = gen_etdump_object(etdump_path=etdump_path)
374+
self.event_blocks = EventBlock._gen_from_etdump(etdump)
375+
376+
self._op_graph_dict: Mapping[
377+
str, OperatorGraphWithStats
378+
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
379+
380+
# Use the delegate map from etrecord, associate debug handles with each event
381+
for event_block in self.event_blocks:
382+
event_block._gen_resolve_debug_handles(
383+
self._etrecord._debug_handle_map[FORWARD],
384+
self._etrecord._delegate_map[FORWARD]
385+
if self._etrecord._delegate_map is not None
386+
else None,
387+
)
311388

312-
self.event_blocks: List[EventBlock] = []
313-
# TODO: create event blocks from etdump, and associate events with op graph nodes
389+
# Traverse the edge dialect op graph to create mapping from debug_handle to op node
390+
debug_handle_to_op_node_map = {}
391+
create_debug_handle_to_op_node_mapping(
392+
self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
393+
debug_handle_to_op_node_map,
394+
)
395+
396+
for event_block in self.event_blocks:
397+
for event in event_block.events:
398+
event._associate_with_op_graph_nodes(debug_handle_to_op_node_map)
314399

315400
def print_data_tabular(self) -> None:
316401
"""
@@ -322,14 +407,31 @@ def style_text_size(val, size=12):
322407

323408
df_list = [event_block.to_dataframe() for event_block in self.event_blocks]
324409
combined_df = pd.concat(df_list, ignore_index=True)
325-
# TODO: filter out raw, delegate_debug_identifier, stack_traces and module_hierarchy
410+
# Filter out some columns for better readability when printing
411+
filtered_df = combined_df.drop(columns=EXCLUDED_COLUMNS_WHEN_PRINTING)
326412
try:
327413
from IPython.display import display
328414

329-
styled_df = combined_df.style.applymap(style_text_size)
415+
styled_df = filtered_df.style.applymap(style_text_size)
330416
display(styled_df)
331417
except:
332-
print(tabulate(combined_df, headers="keys", tablefmt="fancy_grid"))
418+
# TODO: figure out how to trigger this path in python shell
419+
print(tabulate(filtered_df, headers="keys", tablefmt="fancy_grid"))
420+
421+
# TODO: write unit test
422+
def find_total_for_module(self, module_name: str):
423+
total = 0.0
424+
for block in self.event_blocks:
425+
for event in block.events:
426+
module_hierarchy = event.module_hierarchy.values()
427+
for hierarchy in module_hierarchy:
428+
if not hierarchy:
429+
continue
430+
found = any(module_name in key for key in hierarchy.keys())
431+
if found:
432+
total += event.perf_data.avg
433+
break
434+
return total
333435

334436
def get_event_blocks(self) -> List[EventBlock]:
335437
"""
@@ -353,12 +455,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
353455
# TODO: implement
354456
pass
355457

356-
# TODO: add a unittest for this function
357-
def get_exported_program(self, graph: Optional[str]) -> ExportedProgram:
458+
def get_exported_program(
459+
self, graph: Optional[str] = EDGE_DIALECT_GRAPH_KEY
460+
) -> ExportedProgram:
358461
"""
359462
Access helper for ETRecord, defaults to returning Edge Dialect Program
463+
464+
Args:
465+
graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
360466
"""
361-
if not graph:
362-
return self._etrecord["edge_dialect_output/forward"]
363-
else:
364-
return self._etrecord.get(graph)
467+
return self._etrecord.graph_map.get(graph)

sdk/etdb/tests/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ python_unittest(
66
name = "inspector_test",
77
srcs = ["inspector_test.py"],
88
deps = [
9+
"//executorch/exir:lib",
10+
"//executorch/sdk/edir:et_schema",
911
"//executorch/sdk/etdb:inspector",
12+
"//executorch/sdk/etrecord:etrecord",
13+
"//executorch/sdk/etrecord/tests:etrecord_test_library",
1014
],
1115
)
1216

@@ -18,3 +22,14 @@ python_unittest(
1822
"//executorch/sdk/etdump:schema_flatcc",
1923
],
2024
)
25+
26+
python_unittest(
27+
name = "inspector_utils_test",
28+
srcs = ["inspector_utils_test.py"],
29+
deps = [
30+
"//executorch/sdk/edir:et_schema",
31+
"//executorch/sdk/etdb:inspector_utils",
32+
"//executorch/sdk/etrecord:etrecord",
33+
"//executorch/sdk/etrecord/tests:etrecord_test_library",
34+
],
35+
)

0 commit comments

Comments
 (0)