Skip to content

Commit b74c68d

Browse files
authored
Refactor debug_handles type to use DebugHandles type alias
Differential Revision: D77420162 Pull Request resolved: #12061
1 parent 32f96f6 commit b74c68d

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

devtools/inspector/_inspector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from executorch.devtools.inspector._inspector_utils import (
4444
calculate_time_scale_factor,
4545
create_debug_handle_to_op_node_mapping,
46+
DebugHandle,
4647
display_or_print_df,
4748
EDGE_DIALECT_GRAPH_KEY,
4849
EXCLUDED_COLUMNS_WHEN_PRINTING,
@@ -262,7 +263,7 @@ class RunSignature:
262263

263264
# Typing for mapping Event.delegate_debug_identifiers to debug_handle(s)
264265
DelegateIdentifierDebugHandleMap: TypeAlias = Union[
265-
Mapping[int, Tuple[int, ...]], Mapping[str, Tuple[int, ...]]
266+
Mapping[int, DebugHandle], Mapping[str, DebugHandle]
266267
]
267268

268269
# Typing for Dict containig delegate metadata
@@ -1149,7 +1150,7 @@ def _consume_etrecord(self) -> None:
11491150

11501151
def _get_aot_intermediate_outputs_and_op_names(
11511152
self,
1152-
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
1153+
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
11531154
"""
11541155
Capture intermediate outputs only if _representative_inputs are provided
11551156
when using bundled program to create the etrecord
@@ -1170,7 +1171,7 @@ def _get_aot_intermediate_outputs_and_op_names(
11701171
# TODO: Make it more extensible to further merge overlapping debug handles
11711172
def _get_runtime_intermediate_outputs_and_op_names(
11721173
self,
1173-
) -> Tuple[Dict[Tuple[int, ...], Any], Dict[Tuple[int, ...], str]]:
1174+
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, str]]:
11741175
"""
11751176
Retrieve the runtime intermediate outputs(debug handles and intermediate values mappings)
11761177
from the event blocks, along with the corresponding debug handles and op names mapping.

devtools/inspector/_inspector_utils.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class TimeScale(Enum):
7373
TimeScale.CYCLES: 1,
7474
}
7575

76+
DebugHandle: TypeAlias = Tuple[int, ...]
77+
7678

7779
class NodeSource(Enum):
7880
AOT = 1
@@ -528,7 +530,7 @@ def compare_results(
528530
return results
529531

530532

531-
def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]):
533+
def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]):
532534
"""
533535
Merge overlapping debug handles int a single key
534536
"""
@@ -558,7 +560,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
558560

559561

560562
def _debug_handles_have_overlap(
561-
aot_debug_hanlde: Tuple[int, ...], runtime_debug_handle: Tuple[int, ...]
563+
aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle
562564
) -> bool:
563565
"""
564566
Check if the AOT debug handle and the runtime debug handle have any overlap.
@@ -568,7 +570,7 @@ def _debug_handles_have_overlap(
568570
return len(aot_set.intersection(runtime_set)) > 0
569571

570572

571-
def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, ...]:
573+
def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle:
572574
"""Combine multiple debug handles into one debug handle"""
573575
combined_debug_handles_set = set()
574576
for debug_handle in debug_handles:
@@ -577,8 +579,8 @@ def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, .
577579

578580

579581
def _combine_overlapped_intermediate_outputs(
580-
nodes: List[Tuple[Tuple[int, ...], Any]]
581-
) -> Tuple[Tuple[int, ...], Any]:
582+
nodes: List[Tuple[DebugHandle, Any]]
583+
) -> Tuple[DebugHandle, Any]:
582584
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
583585
debug_handles = [debug_handle for debug_handle, _ in nodes]
584586
outputs = [output for _, output in nodes]
@@ -588,8 +590,8 @@ def _combine_overlapped_intermediate_outputs(
588590

589591

590592
def _create_debug_handle_overlap_graph(
591-
aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
592-
runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
593+
aot_intermediate_outputs: Dict[DebugHandle, Any],
594+
runtime_intermediate_outputs: Dict[DebugHandle, Any],
593595
) -> Tuple[List[NodeData], Dict[int, List[int]]]:
594596
"""
595597
Create a graph representing overlapping debug handles between AOT and runtime outputs.
@@ -659,15 +661,15 @@ def dfs(node_id, component):
659661

660662

661663
def map_runtime_aot_intermediate_outputs(
662-
aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
663-
runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
664-
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
664+
aot_intermediate_outputs: Dict[DebugHandle, Any],
665+
runtime_intermediate_outputs: Dict[DebugHandle, Any],
666+
) -> Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]:
665667
"""
666668
Map the runtime intermediate outputs to the AOT intermediate outputs
667669
by finding overlapping debug handles and combining them into a single debug_handle
668670
669671
Returns:
670-
Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
672+
Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] - Mapping
671673
from runtime intermediate output to AOT intermediate output
672674
"""
673675
# Merge overlapping debug handles
@@ -760,13 +762,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
760762

761763
def get_aot_debug_handle_to_op_name_mapping(
762764
graph_module: torch.fx.GraphModule,
763-
) -> Dict[Tuple[int, ...], str]:
765+
) -> Dict[DebugHandle, str]:
764766
"""
765767
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
766768
Parameters:
767769
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
768770
Returns:
769-
Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names.
771+
Dict[DebugHandle, str]: A dictionary mapping debug handles to operator names.
770772
"""
771773
node_filters = [
772774
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
@@ -787,8 +789,8 @@ def get_aot_debug_handle_to_op_name_mapping(
787789

788790

789791
def find_op_names(
790-
target_debug_handle: Tuple[int, ...],
791-
debug_handle_to_op_name: Dict[Tuple[int, ...], str],
792+
target_debug_handle: DebugHandle,
793+
debug_handle_to_op_name: Dict[DebugHandle, str],
792794
) -> List[str]:
793795
"""
794796
Record the operator names only if their debug handles are part of the target debug handle.

devtools/inspector/_intermediate_output_capturer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
# pyre-unsafe
88

99

10-
from typing import Any, Dict, Tuple
10+
from typing import Any, Dict
1111

1212
import torch
13-
from executorch.devtools.inspector._inspector_utils import NodeFilter
13+
from executorch.devtools.inspector._inspector_utils import DebugHandle, NodeFilter
1414
from torch.fx import GraphModule
1515
from torch.fx.interpreter import Interpreter
1616

@@ -30,7 +30,7 @@ def __init__(self, module: GraphModule):
3030
]
3131

3232
# Runs the graph module and captures the intermediate outputs.
33-
def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
33+
def run_and_capture(self, *args, **kwargs) -> Dict[DebugHandle, Any]:
3434
captured_outputs = {}
3535

3636
def capture_run_node(n: torch.fx.Node) -> Any:

0 commit comments

Comments
 (0)