|
9 | 9 | import math
|
10 | 10 | import sys
|
11 | 11 | from enum import Enum
|
12 |
| -from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union |
| 12 | +from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union |
13 | 13 |
|
14 | 14 | import executorch.devtools.etdump.schema_flatcc as flatcc
|
15 | 15 |
|
@@ -483,3 +483,32 @@ def compare_results(
|
483 | 483 | print("\n")
|
484 | 484 |
|
485 | 485 | return results
|
| 486 | + |
| 487 | + |
| 488 | +def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...], Any]): |
| 489 | + """ |
| 490 | + Merge overlapping debug handles int a single key |
| 491 | + """ |
| 492 | + if not intermediate_outputs: |
| 493 | + return |
| 494 | + # Extract and normalize into (start, end, val) |
| 495 | + intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()] |
| 496 | + intervals.sort(key=lambda x: x[0]) |
| 497 | + |
| 498 | + # Merge overlapping debug_hanldes, picking the last value |
| 499 | + merged_intermediate_outputs = [] |
| 500 | + cur_start, cur_end, cur_val = intervals[0] |
| 501 | + for start, end, val in intervals[1:]: |
| 502 | + if start <= cur_end: # Overlaps |
| 503 | + if end > cur_end: # Extend if this one goes further |
| 504 | + cur_end, cur_val = end, val |
| 505 | + |
| 506 | + else: |
| 507 | + merged_intermediate_outputs.append((cur_start, cur_end, cur_val)) |
| 508 | + cur_start, cur_end, cur_val = start, end, val |
| 509 | + merged_intermediate_outputs.append((cur_start, cur_end, cur_val)) |
| 510 | + |
| 511 | + # Clear original one and populate with merged keys (value will point to the same object) |
| 512 | + intermediate_outputs.clear() |
| 513 | + for start, end, val in merged_intermediate_outputs: |
| 514 | + intermediate_outputs[tuple(range(start, end + 1))] = val |
0 commit comments