Skip to content

Commit 7650667

Browse files
authored
Add a default delegate time scale converter
Differential Revision: D62160650 Pull Request resolved: #5076
1 parent 59d9bad commit 7650667

File tree

4 files changed

+78
-9
lines changed

4 files changed

+78
-9
lines changed

devtools/inspector/_inspector.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
import dataclasses
810
import logging
911
import sys
@@ -39,6 +41,7 @@
3941
)
4042
from executorch.devtools.etrecord import ETRecord, parse_etrecord
4143
from executorch.devtools.inspector._inspector_utils import (
44+
calculate_time_scale_factor,
4245
create_debug_handle_to_op_node_mapping,
4346
EDGE_DIALECT_GRAPH_KEY,
4447
EXCLUDED_COLUMNS_WHEN_PRINTING,
@@ -52,7 +55,6 @@
5255
is_inference_output_equal,
5356
ProgramOutput,
5457
RESERVED_FRAMEWORK_EVENT_NAMES,
55-
TIME_SCALE_DICT,
5658
TimeScale,
5759
verify_debug_data_equivalence,
5860
)
@@ -799,9 +801,7 @@ class GroupedRunInstances:
799801

800802
# Construct the EventBlocks
801803
event_blocks = []
802-
scale_factor = (
803-
TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale]
804-
)
804+
scale_factor = calculate_time_scale_factor(source_time_scale, target_time_scale)
805805
for run_signature, grouped_run_instance in run_groups.items():
806806
run_group: OrderedDict[EventSignature, List[InstructionEvent]] = (
807807
grouped_run_instance.events
@@ -966,6 +966,9 @@ def __init__(
966966
debug_buffer_path: Debug buffer file path that contains the debug data referenced by ETDump for intermediate and program outputs.
967967
delegate_metadata_parser: Optional function to parse delegate metadata from an Profiling Event. Expected signature of the function is:
968968
(delegate_metadata_list: List[bytes]) -> Union[List[str], Dict[str, Any]]
969+
delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of
970+
target_time_scale/source_time_scale.
971+
enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False.
969972
970973
Returns:
971974
None
@@ -980,6 +983,14 @@ def __init__(
980983
self._source_time_scale = source_time_scale
981984
self._target_time_scale = target_time_scale
982985

986+
if delegate_time_scale_converter is None:
987+
scale_factor = calculate_time_scale_factor(
988+
source_time_scale, target_time_scale
989+
)
990+
delegate_time_scale_converter = (
991+
lambda event_name, input_time: input_time / scale_factor
992+
)
993+
983994
if etrecord is None:
984995
self._etrecord = None
985996
elif isinstance(etrecord, ETRecord):
@@ -1002,10 +1013,10 @@ def __init__(
10021013
)
10031014

10041015
self.event_blocks = EventBlock._gen_from_etdump(
1005-
etdump,
1006-
self._source_time_scale,
1007-
self._target_time_scale,
1008-
output_buffer,
1016+
etdump=etdump,
1017+
source_time_scale=self._source_time_scale,
1018+
target_time_scale=self._target_time_scale,
1019+
output_buffer=output_buffer,
10091020
delegate_metadata_parser=delegate_metadata_parser,
10101021
delegate_time_scale_converter=delegate_time_scale_converter,
10111022
)

devtools/inspector/_inspector_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
import math
810
from enum import Enum
911
from typing import Dict, List, Mapping, Optional, Tuple, TypeAlias, Union
@@ -63,6 +65,15 @@ class TimeScale(Enum):
6365
}
6466

6567

68+
def calculate_time_scale_factor(
69+
source_time_scale: TimeScale, target_time_scale: TimeScale
70+
) -> float:
71+
"""
72+
Calculate the factor (source divided by target) between two time scales
73+
"""
74+
return TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale]
75+
76+
6677
# Model Debug Output
6778
InferenceOutput: TypeAlias = Union[
6879
torch.Tensor, List[torch.Tensor], int, float, str, bool, None

devtools/inspector/tests/inspector_test.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
import random
810
import statistics
911
import tempfile
1012
import unittest
1113
from contextlib import redirect_stdout
1214

13-
from typing import List
15+
from typing import Callable, List
1416

1517
from unittest.mock import patch
1618

@@ -32,6 +34,7 @@
3234
InstructionEvent,
3335
InstructionEventSignature,
3436
ProfileEventSignature,
37+
TimeScale,
3538
)
3639

3740
from executorch.exir import ExportedProgram
@@ -88,6 +91,33 @@ def test_inspector_constructor(self):
8891
# Because we mocked parse_etrecord() to return None, this method shouldn't be called
8992
mock_gen_graphs_from_etrecord.assert_not_called()
9093

94+
def test_default_delegate_time_scale_converter(self):
95+
# Create a context manager to patch functions called by Inspector.__init__
96+
with patch.object(
97+
_inspector, "parse_etrecord", return_value=None
98+
), patch.object(
99+
_inspector, "gen_etdump_object", return_value=None
100+
), patch.object(
101+
EventBlock, "_gen_from_etdump"
102+
) as mock_gen_from_etdump, patch.object(
103+
_inspector, "gen_graphs_from_etrecord"
104+
), patch.object(
105+
_inspector, "create_debug_handle_to_op_node_mapping"
106+
):
107+
# Call the constructor of Inspector
108+
Inspector(
109+
etdump_path=ETDUMP_PATH,
110+
etrecord=ETRECORD_PATH,
111+
source_time_scale=TimeScale.US,
112+
target_time_scale=TimeScale.S,
113+
)
114+
115+
# Verify delegate_time_scale_converter is set to be a callable
116+
self.assertIsInstance(
117+
mock_gen_from_etdump.call_args.get("delegate_time_scale_converter"),
118+
Callable,
119+
)
120+
91121
def test_inspector_print_data_tabular(self):
92122
# Create a context manager to patch functions called by Inspector.__init__
93123
with patch.object(

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
import tempfile
810
import unittest
911
from typing import Dict, Tuple
@@ -23,11 +25,13 @@
2325

2426
from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
2527
from executorch.devtools.inspector._inspector_utils import (
28+
calculate_time_scale_factor,
2629
create_debug_handle_to_op_node_mapping,
2730
EDGE_DIALECT_GRAPH_KEY,
2831
find_populated_event,
2932
gen_graphs_from_etrecord,
3033
is_inference_output_equal,
34+
TimeScale,
3135
)
3236

3337

@@ -170,6 +174,19 @@ def test_is_inference_output_equal_returns_true_for_same_strs(self):
170174
)
171175
)
172176

177+
def test_calculate_time_scale_factor_second_based(self):
178+
self.assertEqual(
179+
calculate_time_scale_factor(TimeScale.NS, TimeScale.MS), 1000000
180+
)
181+
self.assertEqual(
182+
calculate_time_scale_factor(TimeScale.MS, TimeScale.NS), 1 / 1000000
183+
)
184+
185+
def test_calculate_time_scale_factor_cycles(self):
186+
self.assertEqual(
187+
calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1
188+
)
189+
173190

174191
def gen_mock_operator_graph_with_expected_map() -> (
175192
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)