Skip to content

Commit b4bd968

Browse files
Juntian777facebook-github-bot
authored andcommitted
Merge intermediate output with overlapped debug_handles
Summary: The function merge_overlapping_debug_handles was added to merge the overlapping debug handles in the intermediate outputs mapping, ensuring that only the last intermediate output is retained for each overlap. Differential Revision: D76304528
1 parent 3f11883 commit b4bd968

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import math
1010
import sys
1111
from enum import Enum
12-
from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
12+
from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union, Any
1313

1414
import executorch.devtools.etdump.schema_flatcc as flatcc
1515

@@ -483,3 +483,30 @@ def compare_results(
483483
print("\n")
484484

485485
return results
486+
487+
def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]):
488+
"""
489+
Merge overlapping debug handles int a single key
490+
"""
491+
492+
# Extract and normalize into (start, end, val)
493+
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
494+
intervals.sort(key=lambda x: x[0])
495+
496+
# Merge overlapping debug_hanldes, picking the last value
497+
merged_intermediate_outputs = []
498+
cur_start, cur_end, cur_val = intervals[0]
499+
for start, end, val in intervals[1:]:
500+
if start <= cur_end: # Overlaps
501+
if end > cur_end: # Extend if this one goes further
502+
cur_end, cur_val = end, val
503+
504+
else:
505+
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
506+
cur_start, cur_end, cur_val = start, end, val
507+
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
508+
509+
# Clear original one and populate with merged keys (value will point to the same object)
510+
intermediate_outputs.clear()
511+
for start, end, val in merged_intermediate_outputs:
512+
intermediate_outputs[tuple(range(start, end + 1))] = val

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
gen_graphs_from_etrecord,
3636
is_inference_output_equal,
3737
TimeScale,
38+
merge_overlapping_debug_handles,
3839
)
3940

4041

@@ -217,6 +218,25 @@ def test_compare_results_uint8(self):
217218
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
218219
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
219220

221+
def test_merge_overlapping_debug_handles(self):
222+
big_tensor = torch.rand(100, 100)
223+
intermediate_outputs = {
224+
(1, 2, 3) : "val1",
225+
(2, 3, 4, 5) : "val2",
226+
(6, 7, 8) : "val3",
227+
(10, 11): "val4",
228+
(11, 12): big_tensor,
229+
}
230+
# basic merge behavior
231+
merge_overlapping_debug_handles(intermediate_outputs)
232+
expected_intermediate_outputs = {
233+
(1, 2, 3, 4, 5) : "val2",
234+
(6, 7, 8) : "val3",
235+
(10, 11, 12): big_tensor,
236+
}
237+
238+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
239+
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
220240

221241
def gen_mock_operator_graph_with_expected_map() -> (
222242
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)