@@ -73,6 +73,8 @@ class TimeScale(Enum):
73
73
TimeScale .CYCLES : 1 ,
74
74
}
75
75
76
+ DebugHandle : TypeAlias = Tuple [int , ...]
77
+
76
78
77
79
class NodeSource (Enum ):
78
80
AOT = 1
@@ -528,7 +530,7 @@ def compare_results(
528
530
return results
529
531
530
532
531
- def merge_overlapping_debug_handles (intermediate_outputs : Dict [Tuple [ int , ...] , Any ]):
533
+ def merge_overlapping_debug_handles (intermediate_outputs : Dict [DebugHandle , Any ]):
532
534
"""
533
535
Merge overlapping debug handles int a single key
534
536
"""
@@ -558,7 +560,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
558
560
559
561
560
562
def _debug_handles_have_overlap (
561
- aot_debug_hanlde : Tuple [ int , ...], runtime_debug_handle : Tuple [ int , ...]
563
+ aot_debug_hanlde : DebugHandle , runtime_debug_handle : DebugHandle
562
564
) -> bool :
563
565
"""
564
566
Check if the AOT debug handle and the runtime debug handle have any overlap.
@@ -568,7 +570,7 @@ def _debug_handles_have_overlap(
568
570
return len (aot_set .intersection (runtime_set )) > 0
569
571
570
572
571
- def _combine_debug_hanldes (debug_handles : List [Tuple [ int , ...]] ) -> Tuple [ int , ...] :
573
+ def _combine_debug_hanldes (debug_handles : List [DebugHandle ] ) -> DebugHandle :
572
574
"""Combine multiple debug handles into one debug handle"""
573
575
combined_debug_handles_set = set ()
574
576
for debug_handle in debug_handles :
@@ -577,8 +579,8 @@ def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, .
577
579
578
580
579
581
def _combine_overlapped_intermediate_outputs (
580
- nodes : List [Tuple [Tuple [ int , ...] , Any ]]
581
- ) -> Tuple [Tuple [ int , ...] , Any ]:
582
+ nodes : List [Tuple [DebugHandle , Any ]]
583
+ ) -> Tuple [DebugHandle , Any ]:
582
584
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
583
585
debug_handles = [debug_handle for debug_handle , _ in nodes ]
584
586
outputs = [output for _ , output in nodes ]
@@ -588,8 +590,8 @@ def _combine_overlapped_intermediate_outputs(
588
590
589
591
590
592
def _create_debug_handle_overlap_graph (
591
- aot_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
592
- runtime_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
593
+ aot_intermediate_outputs : Dict [DebugHandle , Any ],
594
+ runtime_intermediate_outputs : Dict [DebugHandle , Any ],
593
595
) -> Tuple [List [NodeData ], Dict [int , List [int ]]]:
594
596
"""
595
597
Create a graph representing overlapping debug handles between AOT and runtime outputs.
@@ -659,15 +661,15 @@ def dfs(node_id, component):
659
661
660
662
661
663
def map_runtime_aot_intermediate_outputs (
662
- aot_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
663
- runtime_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
664
- ) -> Dict [Tuple [Tuple [ int , ...], Any ], Tuple [Tuple [ int , ...] , Any ]]:
664
+ aot_intermediate_outputs : Dict [DebugHandle , Any ],
665
+ runtime_intermediate_outputs : Dict [DebugHandle , Any ],
666
+ ) -> Dict [Tuple [DebugHandle , Any ], Tuple [DebugHandle , Any ]]:
665
667
"""
666
668
Map the runtime intermediate outputs to the AOT intermediate outputs
667
669
by finding overlapping debug handles and combining them into a single debug_handle
668
670
669
671
Returns:
670
- Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...] , Any]] - Mapping
672
+ Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle , Any]] - Mapping
671
673
from runtime intermediate output to AOT intermediate output
672
674
"""
673
675
# Merge overlapping debug handles
@@ -760,13 +762,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
760
762
761
763
def get_aot_debug_handle_to_op_name_mapping (
762
764
graph_module : torch .fx .GraphModule ,
763
- ) -> Dict [Tuple [ int , ...] , str ]:
765
+ ) -> Dict [DebugHandle , str ]:
764
766
"""
765
767
Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
766
768
Parameters:
767
769
graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
768
770
Returns:
769
- Dict[Tuple[int, ...] , str]: A dictionary mapping debug handles to operator names.
771
+ Dict[DebugHandle , str]: A dictionary mapping debug handles to operator names.
770
772
"""
771
773
node_filters = [
772
774
NodeFilter ("debug_handle" , "call_function" , exclude_ops = ["getitem" ])
@@ -787,8 +789,8 @@ def get_aot_debug_handle_to_op_name_mapping(
787
789
788
790
789
791
def find_op_names (
790
- target_debug_handle : Tuple [ int , ...] ,
791
- debug_handle_to_op_name : Dict [Tuple [ int , ...] , str ],
792
+ target_debug_handle : DebugHandle ,
793
+ debug_handle_to_op_name : Dict [DebugHandle , str ],
792
794
) -> List [str ]:
793
795
"""
794
796
Record the operator names only if their debug handles are part of the target debug handle.
0 commit comments