Skip to content

Commit ea8b1ef

Browse files
Juntian777facebook-github-bot
authored andcommitted
Merge intermediate output with overlapped debug_handles (#11529)
Summary: Pull Request resolved: #11529 This PR added the function merge_overlapping_debug_handles to merge the overlapping debug handles in the intermediate outputs mapping, ensuring that only the last intermediate output is retained for each overlap. It will be called later when do the mapping between aot_intermediate_outputs and runtime_intermediate_outputs as the first step to do pre-processing. Reviewed By: larryliu0820 Differential Revision: D76304528
1 parent d719e8e commit ea8b1ef

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import math
1010
import sys
1111
from enum import Enum
12-
from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
12+
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
1313

1414
import executorch.devtools.etdump.schema_flatcc as flatcc
1515

@@ -483,3 +483,32 @@ def compare_results(
483483
print("\n")
484484

485485
return results
486+
487+
488+
def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]):
489+
"""
490+
Merge overlapping debug handles int a single key
491+
"""
492+
if not intermediate_outputs:
493+
return
494+
# Extract and normalize into (start, end, val)
495+
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
496+
intervals.sort(key=lambda x: x[0])
497+
498+
# Merge overlapping debug_hanldes, picking the last value
499+
merged_intermediate_outputs = []
500+
cur_start, cur_end, cur_val = intervals[0]
501+
for start, end, val in intervals[1:]:
502+
if start <= cur_end: # Overlaps
503+
if end > cur_end: # Extend if this one goes further
504+
cur_end, cur_val = end, val
505+
506+
else:
507+
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
508+
cur_start, cur_end, cur_val = start, end, val
509+
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
510+
511+
# Clear original one and populate with merged keys (value will point to the same object)
512+
intermediate_outputs.clear()
513+
for start, end, val in merged_intermediate_outputs:
514+
intermediate_outputs[tuple(range(start, end + 1))] = val

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
find_populated_event,
3535
gen_graphs_from_etrecord,
3636
is_inference_output_equal,
37+
merge_overlapping_debug_handles,
3738
TimeScale,
3839
)
3940

@@ -217,6 +218,26 @@ def test_compare_results_uint8(self):
217218
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
218219
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
219220

221+
def test_merge_overlapping_debug_handles(self):
222+
big_tensor = torch.rand(100, 100)
223+
intermediate_outputs = {
224+
(1, 2, 3): "val1",
225+
(2, 3, 4, 5): "val2",
226+
(6, 7, 8): "val3",
227+
(10, 11): "val4",
228+
(11, 12): big_tensor,
229+
}
230+
# basic merge behavior
231+
merge_overlapping_debug_handles(intermediate_outputs)
232+
expected_intermediate_outputs = {
233+
(1, 2, 3, 4, 5): "val2",
234+
(6, 7, 8): "val3",
235+
(10, 11, 12): big_tensor,
236+
}
237+
238+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
239+
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
240+
220241

221242
def gen_mock_operator_graph_with_expected_map() -> (
222243
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)