Skip to content

Commit 02e5c58

Browse files
authored
Integrate the IntermediateOutputCapturer into Inspector
Differential Revision: D75828351 Pull Request resolved: #11384
1 parent 3dd59f2 commit 02e5c58

File tree

3 files changed

+100
-2
lines changed

3 files changed

+100
-2
lines changed

devtools/inspector/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ python_library(
1818
"//executorch/devtools/etdump:schema_flatcc",
1919
"//executorch/devtools/etrecord:etrecord",
2020
"//executorch/exir:lib",
21+
"//executorch/devtools/inspector:intermediate_output_capturer",
2122
],
2223
)
2324

devtools/inspector/_inspector.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
TimeScale,
6060
verify_debug_data_equivalence,
6161
)
62+
from executorch.devtools.inspector._intermediate_output_capturer import (
63+
IntermediateOutputCapturer,
64+
)
6265
from executorch.exir import ExportedProgram
6366

6467

@@ -1074,6 +1077,7 @@ def __init__(
10741077
# Key str is method name; value is list of ProgramOutputs because of list of test cases
10751078
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
10761079
self._enable_module_hierarchy = enable_module_hierarchy
1080+
self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
10771081
self._consume_etrecord()
10781082

10791083
def _consume_etrecord(self) -> None:
@@ -1134,6 +1138,16 @@ def _consume_etrecord(self) -> None:
11341138
event_block.reference_output = self._reference_outputs[FORWARD][
11351139
index
11361140
]
1141+
# Capture intermediate outputs only if _representative_inputs are provided
1142+
# when using bundled program to create the etrecord
1143+
if self._etrecord._representative_inputs is None:
1144+
return
1145+
export_program = self._etrecord.edge_dialect_program
1146+
graph_module = export_program.module()
1147+
capturer = IntermediateOutputCapturer(graph_module)
1148+
self._aot_intermediate_outputs = capturer.run_and_capture(
1149+
self._etrecord._representative_inputs
1150+
)
11371151

11381152
def to_dataframe(
11391153
self,

devtools/inspector/tests/inspector_test.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88

9+
import copy
910
import random
1011
import statistics
1112
import tempfile
@@ -16,9 +17,13 @@
1617

1718
from unittest.mock import patch
1819

20+
import torch
21+
import torch.fx
22+
1923
from executorch.devtools import generate_etrecord, parse_etrecord
2024
from executorch.devtools.debug_format.et_schema import OperatorNode
2125
from executorch.devtools.etdump.schema_flatcc import ProfileEvent
26+
from executorch.devtools.etrecord._etrecord import ETRecord
2227
from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
2328

2429
from executorch.devtools.inspector import (
@@ -36,8 +41,17 @@
3641
ProfileEventSignature,
3742
TimeScale,
3843
)
39-
40-
from executorch.exir import ExportedProgram
44+
from executorch.devtools.inspector.tests.inspector_test_utils import (
45+
check_if_final_outputs_match,
46+
model_registry,
47+
)
48+
from executorch.exir import (
49+
EdgeCompileConfig,
50+
EdgeProgramManager,
51+
ExecutorchProgramManager,
52+
to_edge,
53+
)
54+
from torch.export import export, ExportedProgram
4155

4256

4357
OP_TYPE = "aten::add"
@@ -452,6 +466,75 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
452466
events=events,
453467
)
454468

469+
def test_no_capture_when_representative_inputs_are_none(self):
470+
# Create a context manager to patch functions called by Inspector.__init__
471+
with patch.object(
472+
_inspector, "parse_etrecord", return_value=None
473+
), patch.object(
474+
_inspector, "gen_etdump_object", return_value=None
475+
), patch.object(
476+
EventBlock, "_gen_from_etdump"
477+
), patch.object(
478+
_inspector, "gen_graphs_from_etrecord"
479+
):
480+
# Call the constructor of Inspector
481+
inspector_instance = Inspector(
482+
etdump_path=ETDUMP_PATH,
483+
etrecord=ETRECORD_PATH,
484+
)
485+
self.assertIsNone(inspector_instance._aot_intermediate_outputs)
486+
487+
def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
488+
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
489+
etrecord_path = tmp_file.name
490+
mod = model_registry["ConvLinearModel"]()
491+
input_tensor = torch.tensor(
492+
[[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True
493+
)
494+
aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True)
495+
edge_program_manager: EdgeProgramManager = to_edge(
496+
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
497+
)
498+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
499+
et_program_manager: ExecutorchProgramManager = (
500+
edge_program_manager.to_executorch()
501+
)
502+
# Generate ETRecord
503+
generate_etrecord(
504+
etrecord_path, edge_program_manager_copy, et_program_manager
505+
)
506+
original_consume_etrecord = Inspector._consume_etrecord
507+
with patch.object(
508+
Inspector, "_consume_etrecord", return_value=None
509+
), patch.object(
510+
_inspector, "gen_etdump_object", return_value=None
511+
), patch.object(
512+
EventBlock, "_gen_from_etdump"
513+
), patch.object(
514+
_inspector, "gen_graphs_from_etrecord"
515+
):
516+
# Call the constructor of Inspector
517+
inspector_instance = Inspector(
518+
etdump_path=ETDUMP_PATH,
519+
etrecord=etrecord_path,
520+
)
521+
etrecord = ETRecord(
522+
edge_dialect_program=inspector_instance._etrecord.edge_dialect_program,
523+
graph_map=inspector_instance._etrecord.graph_map,
524+
_debug_handle_map=inspector_instance._etrecord._debug_handle_map,
525+
_delegate_map=inspector_instance._etrecord._delegate_map,
526+
_reference_outputs=inspector_instance._etrecord._reference_outputs,
527+
_representative_inputs=aten_model.example_inputs[0],
528+
)
529+
inspector_instance._etrecord = etrecord
530+
Inspector._consume_etrecord = original_consume_etrecord
531+
inspector_instance._consume_etrecord()
532+
self.assertTrue(
533+
check_if_final_outputs_match(
534+
"ConvLinearModel", inspector_instance._aot_intermediate_outputs
535+
)
536+
)
537+
455538
def _gen_random_float_list(self) -> List[float]:
456539
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
457540

0 commit comments

Comments
 (0)