Skip to content

Commit 4f1119f

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
Populate Event attributes with op nodes metadata linked by debug handles (#401)
Summary: Pull Request resolved: #401 The main logic change is in the __init__() of Inspector. Differential Revision: D49326221 fbshipit-source-id: 89642a48dac15fe7cd6d727200b4661c7e707c28
1 parent c0123ec commit 4f1119f

File tree

8 files changed

+516
-64
lines changed

8 files changed

+516
-64
lines changed

sdk/etdb/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ python_library(
4141
"//executorch/exir:lib",
4242
"//executorch/sdk/edir:et_schema",
4343
"//executorch/sdk/etdump:schema_flatcc",
44-
"//executorch/sdk/etrecord:etrecord",
4544
],
4645
)
4746

@@ -52,6 +51,8 @@ python_library(
5251
],
5352
deps = [
5453
"//executorch/sdk/edir:et_schema",
54+
"//executorch/sdk/etdump:schema_flatcc",
55+
"//executorch/sdk/etdump:serialize",
5556
"//executorch/sdk/etrecord:etrecord",
5657
],
5758
)

sdk/etdb/_inspector_utils.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Mapping
7+
from typing import Dict, Mapping, Optional
88

9-
from executorch.sdk.edir.et_schema import FXOperatorGraph, OperatorGraphWithStats
10-
from executorch.sdk.etrecord import ETRecord
9+
from executorch.sdk.edir.et_schema import (
10+
FXOperatorGraph,
11+
OperatorGraphWithStats,
12+
OperatorNode,
13+
)
14+
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC
15+
16+
from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
17+
from executorch.sdk.etrecord import ETRecord, parse_etrecord
1118

1219

13-
# TODO: add a unittest for this function
1420
def gen_graphs_from_etrecord(
1521
etrecord: ETRecord,
1622
) -> Mapping[str, OperatorGraphWithStats]:
@@ -20,3 +26,46 @@ def gen_graphs_from_etrecord(
2026
name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module)
2127
for name, exported_program in etrecord.graph_map.items()
2228
}
29+
30+
31+
# TODO: use anonymous function to avoid passing the dict around
32+
# and move this inside of the OperatorGraphWithStats class
33+
def create_debug_handle_to_op_node_mapping(
34+
op_graph: OperatorGraphWithStats,
35+
debug_handle_to_op_node_map: Dict[int, OperatorNode],
36+
) -> None:
37+
"""
38+
Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
39+
from each debug handle to the operator node that contains the debug handle in its metadata.
40+
"""
41+
# Recursively searches through the metadata of nodes
42+
for element in op_graph.elements:
43+
if isinstance(element, OperatorGraphWithStats):
44+
create_debug_handle_to_op_node_mapping(element, debug_handle_to_op_node_map)
45+
if isinstance(element, OperatorNode) and element.metadata is not None:
46+
metadata = element.metadata
47+
debug_handle = metadata.get("debug_handle")
48+
if debug_handle is not None:
49+
existing_entry = debug_handle_to_op_node_map.get(debug_handle)
50+
if existing_entry is not None:
51+
raise ValueError(
52+
f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
53+
"No two op nodes of the same graph should have the same debug handle."
54+
)
55+
debug_handle_to_op_node_map[debug_handle] = element
56+
57+
58+
def gen_etrecord_object(etrecord_path: Optional[str] = None) -> ETRecord:
59+
# Gen op graphs from etrecord
60+
if etrecord_path is None:
61+
raise ValueError("Etrecord_path must be specified.")
62+
return parse_etrecord(etrecord_path=etrecord_path)
63+
64+
65+
def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC:
66+
# Gen event blocks from etdump
67+
if etdump_path is None:
68+
raise ValueError("Etdump_path must be specified.")
69+
with open(etdump_path, "rb") as buff:
70+
etdump = deserialize_from_etdump_flatcc(buff.read())
71+
return etdump

sdk/etdb/inspector.py

Lines changed: 130 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,26 @@
2525
import torch
2626
from executorch.exir import ExportedProgram
2727

28-
from executorch.sdk.edir.et_schema import OperatorGraphWithStats
29-
from executorch.sdk.etdb._inspector_utils import gen_graphs_from_etrecord
28+
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
29+
from executorch.sdk.etdb._inspector_utils import (
30+
create_debug_handle_to_op_node_mapping,
31+
gen_etdump_object,
32+
gen_etrecord_object,
33+
gen_graphs_from_etrecord,
34+
)
3035
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
31-
from executorch.sdk.etrecord import parse_etrecord
36+
3237
from tabulate import tabulate
3338

39+
40+
FORWARD = "forward"
41+
RESERVED_SPECIAL_EVENT_NAMES = [
42+
"Method::init",
43+
"Program::load_method",
44+
"Method::execute",
45+
]
46+
47+
3448
log: logging.Logger = logging.getLogger(__name__)
3549

3650
# Signature of a ProfileEvent
@@ -112,7 +126,7 @@ class Event:
112126

113127
name: str
114128
perf_data: PerfData
115-
op_type: List[str] = dataclasses.field(default_factory=list)
129+
op_types: List[str] = dataclasses.field(default_factory=list)
116130

117131
# Instruction Id of the original profiling event
118132
instruction_id: Optional[int] = None
@@ -123,7 +137,7 @@ class Event:
123137
# Debug Handles in the model graph to which this event is correlated
124138
debug_handles: Optional[Union[int, Sequence[int]]] = None
125139

126-
stack_trace: Dict[str, str] = dataclasses.field(default_factory=dict)
140+
stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict)
127141
module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict)
128142
is_delegated_op: Optional[bool] = None
129143
delegate_backend_name: Optional[str] = None
@@ -138,9 +152,10 @@ def _gen_from_profile_events(
138152
return an Event object matching the ProfileEventSignature, with perf_data
139153
populated from the list of ProfileEvents
140154
"""
141-
delegate_debug_identifier = (
142-
signature.delegate_id or signature.delegate_id_str or None
143-
)
155+
if signature.delegate_id is not None: # 0 is a valid value
156+
delegate_debug_identifier = signature.delegate_id
157+
else:
158+
delegate_debug_identifier = signature.delegate_id_str or None
144159

145160
# Use the delegate identifier as the event name if delegated
146161
is_delegated_op = delegate_debug_identifier is not None
@@ -158,6 +173,28 @@ def _gen_from_profile_events(
158173
is_delegated_op=is_delegated_op,
159174
)
160175

176+
def _associate_with_op_graph_nodes(
177+
self, debug_handle_to_op_node_map: Dict[int, OperatorNode]
178+
) -> None:
179+
"""
180+
Helper function to populate the stack_traces, module_hierarchy and op_types attributes
181+
based on the debug handles of this event
182+
"""
183+
if (debug_handles := self.debug_handles) is None:
184+
return
185+
186+
if isinstance(debug_handles, int):
187+
debug_handles = [debug_handles]
188+
189+
for handle in debug_handles:
190+
node = debug_handle_to_op_node_map.get(handle)
191+
if node is not None and (metadata := node.metadata) is not None:
192+
self.stack_traces[node.name] = metadata.get("stack_trace")
193+
self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
194+
if node.op:
195+
# TODO: consider having this as a dict from node.name -> node.op
196+
self.op_types += [node.op]
197+
161198

162199
@dataclass
163200
class EventBlock:
@@ -186,11 +223,11 @@ def to_dataframe(self) -> pd.DataFrame:
186223
"min": [event.perf_data.min for event in self.events],
187224
"max": [event.perf_data.max for event in self.events],
188225
"median": [event.perf_data.median for event in self.events],
189-
"op_type": [event.op_type for event in self.events],
226+
"op_types": [event.op_types for event in self.events],
190227
"delegate_debug_identifier": [
191228
event.delegate_debug_identifier for event in self.events
192229
],
193-
"stack_traces": [event.stack_trace for event in self.events],
230+
"stack_traces": [event.stack_traces for event in self.events],
194231
"module_hierarchy": [event.module_hierarchy for event in self.events],
195232
"is_delegated_op": [event.is_delegated_op for event in self.events],
196233
"delegate_backend_name": [
@@ -250,8 +287,8 @@ def _gen_from_etdump(etdump: ETDumpFlatCC) -> List["EventBlock"]:
250287

251288
def _gen_resolve_debug_handles(
252289
self,
253-
handle_map: Dict[int, List[int]],
254-
delegate_map: Optional[Dict[int, DelegateMetadata]] = None,
290+
handle_map: Dict[str, List[int]],
291+
delegate_map: Optional[Dict[str, DelegateMetadata]] = None,
255292
):
256293
"""
257294
Given mappings from instruction id to debug handles, populate the
@@ -263,7 +300,7 @@ def _gen_resolve_debug_handles(
263300
for event in self.events:
264301
# Check for the instruction_id in handle map
265302
if (
266-
instruction_id := event.instruction_id
303+
instruction_id := str(event.instruction_id)
267304
) is None or instruction_id not in handle_map:
268305
continue
269306

@@ -285,14 +322,31 @@ def _gen_resolve_debug_handles(
285322

286323
# For delegated events, handles are found via delegateMetadata
287324
event.delegate_backend_name = delegate_metadata.get("name", "")
288-
event.debug_handles = delegate_metadata.get("delegate_map", {}).get(
325+
delegate_metadata_delegate_map = delegate_metadata.get("delegate_map", {})
326+
debug_handles = delegate_metadata_delegate_map.get(
289327
delegate_debug_id # pyre-ignore
290328
)
329+
if debug_handles is not None:
330+
event.debug_handles = debug_handles
331+
else:
332+
event.debug_handles = delegate_metadata_delegate_map.get(
333+
str(delegate_debug_id) # pyre-ignore
334+
)
335+
336+
337+
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
291338

292339

293340
class Inspector:
294341
"""
295-
APIs for examining model architecture and performance stats
342+
APIs for examining model architecture and performance stats.
343+
344+
Public Attributes:
345+
event_blocks: List["EventBlocks"]. Structured data accessible through Inspector for analysis.
346+
347+
Private Attributes:
348+
_etrecord: ETRecord. File under etrecord_path deserialized into an object.
349+
_op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
296350
"""
297351

298352
def __init__(
@@ -302,15 +356,34 @@ def __init__(
302356
Create an inspector instance from the provided ETDump/ETRecord
303357
"""
304358

305-
# Gen op graphs from etrecord
306-
if etrecord_path is not None:
307-
self._etrecord = parse_etrecord(etrecord_path=etrecord_path)
308-
self._op_graph_dict: Mapping[
309-
str, OperatorGraphWithStats
310-
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
359+
self._etrecord = gen_etrecord_object(etrecord_path=etrecord_path)
360+
etdump = gen_etdump_object(etdump_path=etdump_path)
361+
362+
self.event_blocks = EventBlock._gen_from_etdump(etdump)
363+
364+
self._op_graph_dict: Mapping[
365+
str, OperatorGraphWithStats
366+
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
311367

312-
self.event_blocks: List[EventBlock] = []
313-
# TODO: create event blocks from etdump, and associate events with op graph nodes
368+
# Use the delegate map from etrecord, associate debug handles with each event
369+
for event_block in self.event_blocks:
370+
event_block._gen_resolve_debug_handles(
371+
self._etrecord._debug_handle_map[FORWARD],
372+
self._etrecord._delegate_map[FORWARD]
373+
if self._etrecord._delegate_map is not None
374+
else None,
375+
)
376+
377+
# Traverse the edge dialect op graph to create mapping from debug_handle to op node
378+
debug_handle_to_op_node_map = {}
379+
create_debug_handle_to_op_node_mapping(
380+
self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
381+
debug_handle_to_op_node_map,
382+
)
383+
384+
for event_block in self.event_blocks:
385+
for event in event_block.events:
386+
event._associate_with_op_graph_nodes(debug_handle_to_op_node_map)
314387

315388
def print_data_tabular(self) -> None:
316389
"""
@@ -322,14 +395,38 @@ def style_text_size(val, size=12):
322395

323396
df_list = [event_block.to_dataframe() for event_block in self.event_blocks]
324397
combined_df = pd.concat(df_list, ignore_index=True)
325-
# TODO: filter out raw, delegate_debug_identifier, stack_traces and module_hierarchy
398+
# Filter out raw, delegate_debug_identifier, stack_traces, module_hierarchy and debug_data for better readability
399+
columns_to_drop = [
400+
"raw",
401+
"delegate_debug_identifier",
402+
"stack_traces",
403+
"module_hierarchy",
404+
"debug_data",
405+
]
406+
# Drop the specified columns
407+
filtered_df = combined_df.drop(columns=columns_to_drop)
326408
try:
327409
from IPython.display import display
328410

329-
styled_df = combined_df.style.applymap(style_text_size)
411+
styled_df = filtered_df.style.applymap(style_text_size)
330412
display(styled_df)
331413
except:
332-
print(tabulate(combined_df, headers="keys", tablefmt="fancy_grid"))
414+
print(tabulate(filtered_df, headers="keys", tablefmt="fancy_grid"))
415+
416+
# TODO: write unit test
417+
def find_total_for_module(self, module_name: str):
418+
total = 0.0
419+
for block in self.event_blocks:
420+
for event in block.events:
421+
module_hierarchy = event.module_hierarchy.values()
422+
for hierarchy in module_hierarchy:
423+
if not hierarchy:
424+
continue
425+
found = any(module_name in key for key in hierarchy.keys())
426+
if found:
427+
total += event.perf_data.avg
428+
break
429+
return total
333430

334431
def get_event_blocks(self) -> List[EventBlock]:
335432
"""
@@ -353,12 +450,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
353450
# TODO: implement
354451
pass
355452

356-
# TODO: add a unittest for this function
357-
def get_exported_program(self, graph: Optional[str]) -> ExportedProgram:
453+
def get_exported_program(
454+
self, graph: Optional[str] = EDGE_DIALECT_GRAPH_KEY
455+
) -> ExportedProgram:
358456
"""
359457
Access helper for ETRecord, defaults to returning Edge Dialect Program
458+
459+
Args:
460+
graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
360461
"""
361-
if not graph:
362-
return self._etrecord["edge_dialect_output/forward"]
363-
else:
364-
return self._etrecord.get(graph)
462+
return self._etrecord.graph_map.get(graph)

sdk/etdb/tests/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ python_unittest(
66
name = "inspector_test",
77
srcs = ["inspector_test.py"],
88
deps = [
9+
"//executorch/exir:lib",
10+
"//executorch/sdk/edir:et_schema",
911
"//executorch/sdk/etdb:inspector",
12+
"//executorch/sdk/etrecord:etrecord",
13+
"//executorch/sdk/etrecord/tests:etrecord_test_library",
1014
],
1115
)
1216

@@ -18,3 +22,14 @@ python_unittest(
1822
"//executorch/sdk/etdump:schema_flatcc",
1923
],
2024
)
25+
26+
python_unittest(
27+
name = "inspector_utils_test",
28+
srcs = ["inspector_utils_test.py"],
29+
deps = [
30+
"//executorch/sdk/edir:et_schema",
31+
"//executorch/sdk/etdb:inspector_utils",
32+
"//executorch/sdk/etrecord:etrecord",
33+
"//executorch/sdk/etrecord/tests:etrecord_test_library",
34+
],
35+
)

0 commit comments

Comments
 (0)