Skip to content

Commit 6686df3

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
make etrecord optional
Differential Revision: D49844340 fbshipit-source-id: 2036161b590dabec6e87e7e068046bc7cec2f3e4
1 parent 7906c18 commit 6686df3

File tree

4 files changed

+37
-37
lines changed

4 files changed

+37
-37
lines changed

sdk/inspector/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ python_library(
1818
"//executorch/exir:lib",
1919
"//executorch/sdk/edir:et_schema",
2020
"//executorch/sdk/etdump:schema_flatcc",
21+
"//executorch/sdk/etrecord:etrecord",
2122
],
2223
)
2324

sdk/inspector/_inspector_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC
1515

1616
from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
17-
from executorch.sdk.etrecord import ETRecord, parse_etrecord
17+
from executorch.sdk.etrecord import ETRecord
1818

1919
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
2020

@@ -63,13 +63,6 @@ def create_debug_handle_to_op_node_mapping(
6363
debug_handle_to_op_node_map[debug_handle] = element
6464

6565

66-
def gen_etrecord_object(etrecord_path: Optional[str] = None) -> ETRecord:
67-
# Gen op graphs from etrecord
68-
if etrecord_path is None:
69-
raise ValueError("Etrecord_path must be specified.")
70-
return parse_etrecord(etrecord_path=etrecord_path)
71-
72-
7366
def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC:
7467
# Gen event blocks from etdump
7568
if etdump_path is None:

sdk/inspector/inspector.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
import torch
2626
from executorch.exir import ExportedProgram
2727

28-
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
28+
from executorch.sdk.edir.et_schema import OperatorNode
2929
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
30+
from executorch.sdk.etrecord import parse_etrecord
3031
from executorch.sdk.inspector._inspector_utils import (
3132
create_debug_handle_to_op_node_mapping,
3233
EDGE_DIALECT_GRAPH_KEY,
3334
gen_etdump_object,
34-
gen_etrecord_object,
3535
gen_graphs_from_etrecord,
3636
)
3737

@@ -368,7 +368,6 @@ class Inspector:
368368
369369
Private Attributes:
370370
_etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object.
371-
_op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
372371
"""
373372

374373
def __init__(
@@ -387,14 +386,18 @@ def __init__(
387386
defaults to milli (1000ms = 1s).
388387
"""
389388

390-
# TODO: etrecord_path can be optional, so need to support the case when it is not present
391-
self._etrecord = gen_etrecord_object(etrecord_path=etrecord_path)
389+
self._etrecord = (
390+
parse_etrecord(etrecord_path=etrecord_path)
391+
if etrecord_path is not None
392+
else None
393+
)
394+
392395
etdump = gen_etdump_object(etdump_path=etdump_path)
393396
self.event_blocks = EventBlock._gen_from_etdump(etdump, etdump_scale)
394397

395-
self._op_graph_dict: Mapping[
396-
str, OperatorGraphWithStats
397-
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
398+
# No additional data association can be done without ETRecord, so return early
399+
if self._etrecord is None:
400+
return
398401

399402
# Use the delegate map from etrecord, associate debug handles with each event
400403
for event_block in self.event_blocks:
@@ -406,9 +409,10 @@ def __init__(
406409
)
407410

408411
# Traverse the edge dialect op graph to create mapping from debug_handle to op node
412+
op_graph_dict = gen_graphs_from_etrecord(etrecord=self._etrecord)
409413
debug_handle_to_op_node_map = {}
410414
create_debug_handle_to_op_node_mapping(
411-
self._op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
415+
op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
412416
debug_handle_to_op_node_map,
413417
)
414418

@@ -479,13 +483,22 @@ def write_tensorboard_artifact(self, path: str) -> None:
479483
# TODO: implement
480484
pass
481485

482-
def get_exported_program(self, graph: Optional[str] = None) -> ExportedProgram:
486+
def get_exported_program(
487+
self, graph: Optional[str] = None
488+
) -> Optional[ExportedProgram]:
483489
"""
484490
Access helper for ETRecord, defaults to returning Edge Dialect Program
485491
486492
Args:
487493
graph: Name of the graph to access. If None, returns the Edge Dialect Program.
488494
"""
489-
if graph is None:
490-
return self._etrecord.edge_dialect_program
491-
return self._etrecord.graph_map.get(graph)
495+
if self._etrecord is None:
496+
log.warning(
497+
"Exported program is only available when a valid etrecord_path was provided at the time of Inspector construction"
498+
)
499+
return None
500+
return (
501+
self._etrecord.edge_dialect_program
502+
if graph is None
503+
else self._etrecord.graph_map.get(graph)
504+
)

sdk/inspector/tests/inspector_test.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_event_block_to_dataframe(self) -> None:
5454
def test_inspector_constructor(self):
5555
# Create a context manager to patch functions called by Inspector.__init__
5656
with patch.object(
57-
inspector, "gen_etrecord_object", return_value=None
58-
) as mock_gen_etrecord, patch.object(
57+
inspector, "parse_etrecord", return_value=None
58+
) as mock_parse_etrecord, patch.object(
5959
inspector, "gen_etdump_object", return_value=None
6060
) as mock_gen_etdump, patch.object(
6161
EventBlock, "_gen_from_etdump"
@@ -69,20 +69,17 @@ def test_inspector_constructor(self):
6969
)
7070

7171
# Assert that expected functions are called
72-
mock_gen_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH)
72+
mock_parse_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH)
7373
mock_gen_etdump.assert_called_once_with(etdump_path=ETDUMP_PATH)
7474
mock_gen_from_etdump.assert_called_once()
75-
mock_gen_graphs_from_etrecord.assert_called_once()
75+
# Because we mocked parse_etrecord() to return None, this method shouldn't be called
76+
mock_gen_graphs_from_etrecord.assert_not_called()
7677

7778
def test_inspector_get_event_blocks_and_print_data_tabular(self):
7879
# Create a context manager to patch functions called by Inspector.__init__
79-
with patch.object(
80-
inspector, "gen_etrecord_object", return_value=None
81-
), patch.object(
80+
with patch.object(inspector, "parse_etrecord", return_value=None), patch.object(
8281
inspector, "gen_etdump_object", return_value=None
83-
), patch.object(
84-
EventBlock, "_gen_from_etdump"
85-
), patch.object(
82+
), patch.object(EventBlock, "_gen_from_etdump"), patch.object(
8683
inspector, "gen_graphs_from_etrecord"
8784
):
8885
# Call the constructor of Inspector
@@ -189,13 +186,9 @@ def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self):
189186

190187
def test_inspector_get_exported_program(self):
191188
# Create a context manager to patch functions called by Inspector.__init__
192-
with patch.object(
193-
inspector, "gen_etrecord_object", return_value=None
194-
), patch.object(
189+
with patch.object(inspector, "parse_etrecord", return_value=None), patch.object(
195190
inspector, "gen_etdump_object", return_value=None
196-
), patch.object(
197-
EventBlock, "_gen_from_etdump"
198-
), patch.object(
191+
), patch.object(EventBlock, "_gen_from_etdump"), patch.object(
199192
inspector, "gen_graphs_from_etrecord"
200193
), patch.object(
201194
inspector, "create_debug_handle_to_op_node_mapping"

0 commit comments

Comments
 (0)