Skip to content

Populate Event attributes with op nodes metadata linked by debug handles #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdk/etdb/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ python_library(
"//executorch/exir:lib",
"//executorch/sdk/edir:et_schema",
"//executorch/sdk/etdump:schema_flatcc",
"//executorch/sdk/etrecord:etrecord",
],
)

Expand All @@ -52,6 +51,8 @@ python_library(
],
deps = [
"//executorch/sdk/edir:et_schema",
"//executorch/sdk/etdump:schema_flatcc",
"//executorch/sdk/etdump:serialize",
"//executorch/sdk/etrecord:etrecord",
],
)
57 changes: 53 additions & 4 deletions sdk/etdb/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Mapping
from typing import Dict, Mapping, Optional

from executorch.sdk.edir.et_schema import FXOperatorGraph, OperatorGraphWithStats
from executorch.sdk.etrecord import ETRecord
from executorch.sdk.edir.et_schema import (
FXOperatorGraph,
OperatorGraphWithStats,
OperatorNode,
)
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC

from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
from executorch.sdk.etrecord import ETRecord, parse_etrecord


# TODO: add a unittest for this function
def gen_graphs_from_etrecord(
etrecord: ETRecord,
) -> Mapping[str, OperatorGraphWithStats]:
Expand All @@ -20,3 +26,46 @@ def gen_graphs_from_etrecord(
name: FXOperatorGraph.gen_operator_graph(exported_program.graph_module)
for name, exported_program in etrecord.graph_map.items()
}


# TODO: use anonymous function to avoid passing the dict around
# and move this inside of the OperatorGraphWithStats class
def create_debug_handle_to_op_node_mapping(
    op_graph: OperatorGraphWithStats,
    debug_handle_to_op_node_map: Dict[int, OperatorNode],
) -> None:
    """
    Walk every element of op_graph (descending into nested operator graphs) and record,
    for each operator node whose metadata carries a "debug_handle", an entry
    debug_handle -> node in debug_handle_to_op_node_map.

    The mapping is filled in place; nothing is returned.

    Raises:
        ValueError: if a debug handle is already present in the mapping when another
        node carrying the same handle is visited.
    """
    for child in op_graph.elements:
        # Nested graphs contribute their own nodes to the same shared mapping
        if isinstance(child, OperatorGraphWithStats):
            create_debug_handle_to_op_node_mapping(child, debug_handle_to_op_node_map)

        # Only operator nodes that actually carry metadata can provide a handle
        if not isinstance(child, OperatorNode) or child.metadata is None:
            continue

        handle = child.metadata.get("debug_handle")
        if handle is None:
            continue

        previous_owner = debug_handle_to_op_node_map.get(handle)
        if previous_owner is not None:
            raise ValueError(
                f"Duplicated debug handle {str(handle)} shared between {child.name} and {previous_owner.name}. "
                "No two op nodes of the same graph should have the same debug handle."
            )
        debug_handle_to_op_node_map[handle] = child


def gen_etrecord_object(etrecord_path: Optional[str] = None) -> ETRecord:
    """
    Parse the ETRecord located at etrecord_path and return the result of parse_etrecord.

    Raises:
        ValueError: if etrecord_path is None.
    """
    if etrecord_path is not None:
        return parse_etrecord(etrecord_path=etrecord_path)
    raise ValueError("Etrecord_path must be specified.")


def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC:
    """
    Read the file at etdump_path as raw bytes and return the result of passing those
    bytes to deserialize_from_etdump_flatcc.

    Raises:
        ValueError: if etdump_path is None.
    """
    if etdump_path is None:
        raise ValueError("Etdump_path must be specified.")

    # The payload is binary, so the file is opened in "rb" mode
    with open(etdump_path, "rb") as etdump_file:
        raw_bytes = etdump_file.read()
    return deserialize_from_etdump_flatcc(raw_bytes)
171 changes: 137 additions & 34 deletions sdk/etdb/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,33 @@
import torch
from executorch.exir import ExportedProgram

from executorch.sdk.edir.et_schema import OperatorGraphWithStats
from executorch.sdk.etdb._inspector_utils import gen_graphs_from_etrecord
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
from executorch.sdk.etdb._inspector_utils import (
create_debug_handle_to_op_node_mapping,
gen_etdump_object,
gen_etrecord_object,
gen_graphs_from_etrecord,
)
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
from executorch.sdk.etrecord import parse_etrecord

from tabulate import tabulate


# Key used by Inspector to index the ETRecord's debug handle map and delegate map
FORWARD = "forward"
# Event names reserved for special events.
# NOTE(review): no use of this list is visible in this section -- presumably these
# events are handled differently from regular operator events; confirm against callers.
RESERVED_SPECIAL_EVENT_NAMES = [
"Method::init",
"Program::load_method",
"Method::execute",
]
# Dataframe columns dropped by Inspector.print_data_tabular for better readability
EXCLUDED_COLUMNS_WHEN_PRINTING = [
"raw",
"delegate_debug_identifier",
"stack_traces",
"module_hierarchy",
"debug_data",
]


log: logging.Logger = logging.getLogger(__name__)

# Signature of a ProfileEvent
Expand Down Expand Up @@ -112,7 +133,7 @@ class Event:

name: str
perf_data: PerfData
op_type: List[str] = dataclasses.field(default_factory=list)
op_types: List[str] = dataclasses.field(default_factory=list)

# Instruction Id of the original profiling event
instruction_id: Optional[int] = None
Expand All @@ -123,7 +144,7 @@ class Event:
# Debug Handles in the model graph to which this event is correlated
debug_handles: Optional[Union[int, Sequence[int]]] = None

stack_trace: Dict[str, str] = dataclasses.field(default_factory=dict)
stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict)
module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict)
is_delegated_op: Optional[bool] = None
delegate_backend_name: Optional[str] = None
Expand All @@ -138,9 +159,10 @@ def _gen_from_profile_events(
return an Event object matching the ProfileEventSignature, with perf_data
populated from the list of ProfileEvents
"""
delegate_debug_identifier = (
signature.delegate_id or signature.delegate_id_str or None
)
if signature.delegate_id is not None: # 0 is a valid value
delegate_debug_identifier = signature.delegate_id
else:
delegate_debug_identifier = signature.delegate_id_str or None

# Use the delegate identifier as the event name if delegated
is_delegated_op = delegate_debug_identifier is not None
Expand All @@ -158,6 +180,28 @@ def _gen_from_profile_events(
is_delegated_op=is_delegated_op,
)

def _associate_with_op_graph_nodes(
    self, debug_handle_to_op_node_map: Dict[int, OperatorNode]
) -> None:
    """
    Fill in the stack_traces, module_hierarchy and op_types attributes of this event
    by looking up each of its debug handles in debug_handle_to_op_node_map.

    Does nothing when the event has no debug handles. Handles that are missing from
    the map, and nodes whose metadata is None, are skipped.
    """
    handles = self.debug_handles
    if handles is None:
        return

    # A single int handle is treated as a one-element list
    handle_list = [handles] if isinstance(handles, int) else handles

    for handle in handle_list:
        node = debug_handle_to_op_node_map.get(handle)
        if node is None:
            continue
        metadata = node.metadata
        if metadata is None:
            continue

        # Both entries are keyed by the name of the matched op node
        self.stack_traces[node.name] = metadata.get("stack_trace")
        self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
        if node.op:
            # TODO: consider having this as a dict from node.name -> node.op
            self.op_types += [node.op]


@dataclass
class EventBlock:
Expand Down Expand Up @@ -186,11 +230,11 @@ def to_dataframe(self) -> pd.DataFrame:
"min": [event.perf_data.min for event in self.events],
"max": [event.perf_data.max for event in self.events],
"median": [event.perf_data.median for event in self.events],
"op_type": [event.op_type for event in self.events],
"op_types": [event.op_types for event in self.events],
"delegate_debug_identifier": [
event.delegate_debug_identifier for event in self.events
],
"stack_traces": [event.stack_trace for event in self.events],
"stack_traces": [event.stack_traces for event in self.events],
"module_hierarchy": [event.module_hierarchy for event in self.events],
"is_delegated_op": [event.is_delegated_op for event in self.events],
"delegate_backend_name": [
Expand Down Expand Up @@ -248,10 +292,11 @@ def _gen_from_etdump(etdump: ETDumpFlatCC) -> List["EventBlock"]:
for index, profile_events in enumerate(profile_run_groups.values())
]

# TODO: Consider changing ETRecord deserialization logic to cast the ints in string format to actual ints
def _gen_resolve_debug_handles(
self,
handle_map: Dict[int, List[int]],
delegate_map: Optional[Dict[int, DelegateMetadata]] = None,
handle_map: Dict[str, List[int]],
delegate_map: Optional[Dict[str, DelegateMetadata]] = None,
):
"""
Given mappings from instruction id to debug handles, populate the
Expand All @@ -261,10 +306,12 @@ def _gen_resolve_debug_handles(
to obtain the debug_handle via the delegate map
"""
for event in self.events:
# Check if instruction_id is present in the event
if event.instruction_id is None:
continue

# Check for the instruction_id in handle map
if (
instruction_id := event.instruction_id
) is None or instruction_id not in handle_map:
if (instruction_id := str(event.instruction_id)) not in handle_map:
continue

# For non-delegated event, handles are found in handle_map
Expand All @@ -285,14 +332,33 @@ def _gen_resolve_debug_handles(

# For delegated events, handles are found via delegateMetadata
event.delegate_backend_name = delegate_metadata.get("name", "")
event.debug_handles = delegate_metadata.get("delegate_map", {}).get(
delegate_metadata_delegate_map = delegate_metadata.get("delegate_map", {})

# delegate_debug_id can be either int based or string based, therefore we need to check both
debug_handles = delegate_metadata_delegate_map.get(
delegate_debug_id # pyre-ignore
)
if debug_handles is not None:
event.debug_handles = debug_handles
else:
event.debug_handles = delegate_metadata_delegate_map.get(
str(delegate_debug_id) # pyre-ignore
)


# Name of the edge dialect graph: used by Inspector to pick the op graph whose nodes
# are matched to events, and as the default graph for Inspector.get_exported_program
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"


class Inspector:
"""
APIs for examining model architecture and performance stats
APIs for examining model architecture and performance stats.

Public Attributes:
event_blocks: List["EventBlocks"]. Structured data accessible through Inspector for analysis.

Private Attributes:
_etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object.
_op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
"""

def __init__(
Expand All @@ -302,15 +368,34 @@ def __init__(
Create an inspector instance from the provided ETDump/ETRecord
"""

# Gen op graphs from etrecord
if etrecord_path is not None:
self._etrecord = parse_etrecord(etrecord_path=etrecord_path)
self._op_graph_dict: Mapping[
str, OperatorGraphWithStats
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
# TODO: etrecord_path can be optional, so need to support the case when it is not present
self._etrecord = gen_etrecord_object(etrecord_path=etrecord_path)
etdump = gen_etdump_object(etdump_path=etdump_path)
self.event_blocks = EventBlock._gen_from_etdump(etdump)

self._op_graph_dict: Mapping[
str, OperatorGraphWithStats
] = gen_graphs_from_etrecord(etrecord=self._etrecord)

# Use the delegate map from etrecord, associate debug handles with each event
for event_block in self.event_blocks:
event_block._gen_resolve_debug_handles(
self._etrecord._debug_handle_map[FORWARD],
self._etrecord._delegate_map[FORWARD]
if self._etrecord._delegate_map is not None
else None,
)

self.event_blocks: List[EventBlock] = []
# TODO: create event blocks from etdump, and associate events with op graph nodes
# Traverse the edge dialect op graph to create mapping from debug_handle to op node
debug_handle_to_op_node_map = {}
create_debug_handle_to_op_node_mapping(
self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
debug_handle_to_op_node_map,
)

for event_block in self.event_blocks:
for event in event_block.events:
event._associate_with_op_graph_nodes(debug_handle_to_op_node_map)

def print_data_tabular(self) -> None:
"""
Expand All @@ -322,14 +407,31 @@ def style_text_size(val, size=12):

df_list = [event_block.to_dataframe() for event_block in self.event_blocks]
combined_df = pd.concat(df_list, ignore_index=True)
# TODO: filter out raw, delegate_debug_identifier, stack_traces and module_hierarchy
# Filter out some columns for better readability when printing
filtered_df = combined_df.drop(columns=EXCLUDED_COLUMNS_WHEN_PRINTING)
try:
from IPython.display import display

styled_df = combined_df.style.applymap(style_text_size)
styled_df = filtered_df.style.applymap(style_text_size)
display(styled_df)
except:
print(tabulate(combined_df, headers="keys", tablefmt="fancy_grid"))
# TODO: figure out how to trigger this path in python shell
print(tabulate(filtered_df, headers="keys", tablefmt="fancy_grid"))

# TODO: write unit test
def find_total_for_module(self, module_name: str):
    """
    Sum perf_data.avg over every event, across all event blocks, that has at least one
    module hierarchy entry with a key containing module_name as a substring.

    Each matching event contributes once, however many of its hierarchies match.
    Empty or None hierarchies are ignored. Returns 0.0 when nothing matches.
    """

    def _touches_module(event) -> bool:
        # True as soon as any non-empty hierarchy has a key mentioning the module
        for hierarchy in event.module_hierarchy.values():
            if hierarchy and any(module_name in key for key in hierarchy.keys()):
                return True
        return False

    running_total = 0.0
    for event_block in self.event_blocks:
        for event in event_block.events:
            if _touches_module(event):
                running_total += event.perf_data.avg
    return running_total

def get_event_blocks(self) -> List[EventBlock]:
"""
Expand All @@ -353,12 +455,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
# TODO: implement
pass

# TODO: add a unittest for this function
def get_exported_program(self, graph: Optional[str]) -> ExportedProgram:
def get_exported_program(
self, graph: Optional[str] = EDGE_DIALECT_GRAPH_KEY
) -> ExportedProgram:
"""
Access helper for ETRecord, defaults to returning Edge Dialect Program

Args:
graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
"""
if not graph:
return self._etrecord["edge_dialect_output/forward"]
else:
return self._etrecord.get(graph)
return self._etrecord.graph_map.get(graph)
15 changes: 15 additions & 0 deletions sdk/etdb/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ python_unittest(
name = "inspector_test",
srcs = ["inspector_test.py"],
deps = [
"//executorch/exir:lib",
"//executorch/sdk/edir:et_schema",
"//executorch/sdk/etdb:inspector",
"//executorch/sdk/etrecord:etrecord",
"//executorch/sdk/etrecord/tests:etrecord_test_library",
],
)

Expand All @@ -18,3 +22,14 @@ python_unittest(
"//executorch/sdk/etdump:schema_flatcc",
],
)

python_unittest(
name = "inspector_utils_test",
srcs = ["inspector_utils_test.py"],
deps = [
"//executorch/sdk/edir:et_schema",
"//executorch/sdk/etdb:inspector_utils",
"//executorch/sdk/etrecord:etrecord",
"//executorch/sdk/etrecord/tests:etrecord_test_library",
],
)
Loading