Skip to content

Commit 0d244f9

Browse files
authored
Merge intermediate output with overlapped debug_handles
Differential Revision: D76304528 Pull Request resolved: #11529
1 parent 5a45132 commit 0d244f9

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import math
1010
import sys
1111
from enum import Enum
12-
from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
12+
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
1313

1414
import executorch.devtools.etdump.schema_flatcc as flatcc
1515

@@ -483,3 +483,32 @@ def compare_results(
483483
print("\n")
484484

485485
return results
486+
487+
488+
def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]):
489+
"""
490+
Merge overlapping debug handles int a single key
491+
"""
492+
if not intermediate_outputs:
493+
return
494+
# Extract and normalize into (start, end, val)
495+
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
496+
intervals.sort(key=lambda x: x[0])
497+
498+
# Merge overlapping debug_hanldes, picking the last value
499+
merged_intermediate_outputs = []
500+
cur_start, cur_end, cur_val = intervals[0]
501+
for start, end, val in intervals[1:]:
502+
if start <= cur_end: # Overlaps
503+
if end > cur_end: # Extend if this one goes further
504+
cur_end, cur_val = end, val
505+
506+
else:
507+
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
508+
cur_start, cur_end, cur_val = start, end, val
509+
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
510+
511+
# Clear original one and populate with merged keys (value will point to the same object)
512+
intermediate_outputs.clear()
513+
for start, end, val in merged_intermediate_outputs:
514+
intermediate_outputs[tuple(range(start, end + 1))] = val

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
find_populated_event,
3535
gen_graphs_from_etrecord,
3636
is_inference_output_equal,
37+
merge_overlapping_debug_handles,
3738
TimeScale,
3839
)
3940

@@ -217,6 +218,26 @@ def test_compare_results_uint8(self):
217218
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
218219
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
219220

221+
def test_merge_overlapping_debug_handles(self):
222+
big_tensor = torch.rand(100, 100)
223+
intermediate_outputs = {
224+
(1, 2, 3): "val1",
225+
(2, 3, 4, 5): "val2",
226+
(6, 7, 8): "val3",
227+
(10, 11): "val4",
228+
(11, 12): big_tensor,
229+
}
230+
# basic merge behavior
231+
merge_overlapping_debug_handles(intermediate_outputs)
232+
expected_intermediate_outputs = {
233+
(1, 2, 3, 4, 5): "val2",
234+
(6, 7, 8): "val3",
235+
(10, 11, 12): big_tensor,
236+
}
237+
238+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
239+
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)
240+
220241

221242
def gen_mock_operator_graph_with_expected_map() -> (
222243
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)