Skip to content

Make etrecord path optional #572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdk/inspector/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ python_library(
":inspector_utils",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/sdk/edir:base_schema",
"//executorch/sdk/edir:et_schema",
"//executorch/sdk/etdump:schema_flatcc",
"//executorch/sdk/etrecord:etrecord",
],
)

Expand Down
9 changes: 1 addition & 8 deletions sdk/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC

from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
from executorch.sdk.etrecord import ETRecord, parse_etrecord
from executorch.sdk.etrecord import ETRecord

EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"

Expand Down Expand Up @@ -61,13 +61,6 @@ def create_debug_handle_to_op_node_mapping(
debug_handle_to_op_node_map[debug_handle] = element


def gen_etrecord_object(etrecord_path: Optional[str] = None) -> ETRecord:
# Gen op graphs from etrecord
if etrecord_path is None:
raise ValueError("Etrecord_path must be specified.")
return parse_etrecord(etrecord_path=etrecord_path)


def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC:
# Gen event blocks from etdump
if etdump_path is None:
Expand Down
39 changes: 26 additions & 13 deletions sdk/inspector/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import torch
from executorch.exir import ExportedProgram

from executorch.sdk.edir.base_schema import OperatorGraph, OperatorNode
from executorch.sdk.edir.et_schema import OperatorNode
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
from executorch.sdk.etrecord import parse_etrecord
from executorch.sdk.inspector._inspector_utils import (
create_debug_handle_to_op_node_mapping,
EDGE_DIALECT_GRAPH_KEY,
gen_etdump_object,
gen_etrecord_object,
gen_graphs_from_etrecord,
)

Expand Down Expand Up @@ -368,7 +368,6 @@ class Inspector:

Private Attributes:
_etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object.
_op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
"""

def __init__(
Expand All @@ -387,14 +386,18 @@ def __init__(
defaults to milli (1000ms = 1s).
"""

# TODO: etrecord_path can be optional, so need to support the case when it is not present
self._etrecord = gen_etrecord_object(etrecord_path=etrecord_path)
self._etrecord = (
parse_etrecord(etrecord_path=etrecord_path)
if etrecord_path is not None
else None
)

etdump = gen_etdump_object(etdump_path=etdump_path)
self.event_blocks = EventBlock._gen_from_etdump(etdump, etdump_scale)

self._op_graph_dict: Mapping[str, OperatorGraph] = gen_graphs_from_etrecord(
etrecord=self._etrecord
)
# No additional data association can be done without ETRecord, so return early
if self._etrecord is None:
return

# Use the delegate map from etrecord, associate debug handles with each event
for event_block in self.event_blocks:
Expand All @@ -406,9 +409,10 @@ def __init__(
)

# Traverse the edge dialect op graph to create mapping from debug_handle to op node
op_graph_dict = gen_graphs_from_etrecord(etrecord=self._etrecord)
debug_handle_to_op_node_map = {}
create_debug_handle_to_op_node_mapping(
self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
debug_handle_to_op_node_map,
)

Expand Down Expand Up @@ -479,13 +483,22 @@ def write_tensorboard_artifact(self, path: str) -> None:
# TODO: implement
pass

def get_exported_program(self, graph: Optional[str] = None) -> ExportedProgram:
def get_exported_program(
self, graph: Optional[str] = None
) -> Optional[ExportedProgram]:
"""
Access helper for ETRecord, defaults to returning Edge Dialect Program

Args:
graph: Name of the graph to access. If None, returns the Edge Dialect Program.
"""
if graph is None:
return self._etrecord.edge_dialect_program
return self._etrecord.graph_map.get(graph)
if self._etrecord is None:
log.warning(
"Exported program is only available when a valid etrecord_path was provided at the time of Inspector construction"
)
return None
return (
self._etrecord.edge_dialect_program
if graph is None
else self._etrecord.graph_map.get(graph)
)
25 changes: 9 additions & 16 deletions sdk/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_event_block_to_dataframe(self) -> None:
def test_inspector_constructor(self):
# Create a context manager to patch functions called by Inspector.__init__
with patch.object(
inspector, "gen_etrecord_object", return_value=None
) as mock_gen_etrecord, patch.object(
inspector, "parse_etrecord", return_value=None
) as mock_parse_etrecord, patch.object(
inspector, "gen_etdump_object", return_value=None
) as mock_gen_etdump, patch.object(
EventBlock, "_gen_from_etdump"
Expand All @@ -69,20 +69,17 @@ def test_inspector_constructor(self):
)

# Assert that expected functions are called
mock_gen_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH)
mock_parse_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH)
mock_gen_etdump.assert_called_once_with(etdump_path=ETDUMP_PATH)
mock_gen_from_etdump.assert_called_once()
mock_gen_graphs_from_etrecord.assert_called_once()
# Because we mocked parse_etrecord() to return None, this method shouldn't be called
mock_gen_graphs_from_etrecord.assert_not_called()

def test_inspector_get_event_blocks_and_print_data_tabular(self):
# Create a context manager to patch functions called by Inspector.__init__
with patch.object(
inspector, "gen_etrecord_object", return_value=None
), patch.object(
with patch.object(inspector, "parse_etrecord", return_value=None), patch.object(
inspector, "gen_etdump_object", return_value=None
), patch.object(
EventBlock, "_gen_from_etdump"
), patch.object(
), patch.object(EventBlock, "_gen_from_etdump"), patch.object(
inspector, "gen_graphs_from_etrecord"
):
# Call the constructor of Inspector
Expand Down Expand Up @@ -189,13 +186,9 @@ def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self):

def test_inspector_get_exported_program(self):
# Create a context manager to patch functions called by Inspector.__init__
with patch.object(
inspector, "gen_etrecord_object", return_value=None
), patch.object(
with patch.object(inspector, "parse_etrecord", return_value=None), patch.object(
inspector, "gen_etdump_object", return_value=None
), patch.object(
EventBlock, "_gen_from_etdump"
), patch.object(
), patch.object(EventBlock, "_gen_from_etdump"), patch.object(
inspector, "gen_graphs_from_etrecord"
), patch.object(
inspector, "create_debug_handle_to_op_node_mapping"
Expand Down