25
25
import torch
26
26
from executorch .exir import ExportedProgram
27
27
28
- from executorch .sdk .edir .et_schema import OperatorGraphWithStats
29
- from executorch .sdk .etdb ._inspector_utils import gen_graphs_from_etrecord
28
+ from executorch .sdk .edir .et_schema import OperatorGraphWithStats , OperatorNode
29
+ from executorch .sdk .etdb ._inspector_utils import (
30
+ create_debug_handle_to_op_node_mapping ,
31
+ gen_etdump_and_etrecord_objects ,
32
+ gen_graphs_from_etrecord ,
33
+ )
30
34
from executorch .sdk .etdump .schema_flatcc import ETDumpFlatCC , ProfileEvent
31
- from executorch . sdk . etrecord import parse_etrecord
35
+
32
36
from tabulate import tabulate
33
37
34
38
log : logging .Logger = logging .getLogger (__name__ )
@@ -112,7 +116,7 @@ class Event:
112
116
113
117
name : str
114
118
perf_data : PerfData
115
- op_type : List [str ] = dataclasses .field (default_factory = list )
119
+ op_types : List [str ] = dataclasses .field (default_factory = list )
116
120
117
121
# Instruction Id of the original profiling event
118
122
instruction_id : Optional [int ] = None
@@ -123,7 +127,7 @@ class Event:
123
127
# Debug Handles in the model graph to which this event is correlated
124
128
debug_handles : Optional [Union [int , Sequence [int ]]] = None
125
129
126
- stack_trace : Dict [str , str ] = dataclasses .field (default_factory = dict )
130
+ stack_traces : Dict [str , str ] = dataclasses .field (default_factory = dict )
127
131
module_hierarchy : Dict [str , Dict ] = dataclasses .field (default_factory = dict )
128
132
is_delegated_op : Optional [bool ] = None
129
133
delegate_backend_name : Optional [str ] = None
@@ -158,6 +162,33 @@ def _gen_from_profile_events(
158
162
is_delegated_op = is_delegated_op ,
159
163
)
160
164
165
+ def _associate_with_op_graph_nodes (
166
+ self , debug_handle_to_op_node_map : Dict [int , OperatorNode ]
167
+ ) -> None :
168
+ """
169
+ Helper function to populate the stack_traces, module_hierarchy and op_types attributes
170
+ based on the debug handles of this event
171
+ """
172
+ debug_handles = []
173
+ if self .debug_handles is None :
174
+ return
175
+
176
+ if isinstance (self .debug_handles , int ):
177
+ debug_handles = [self .debug_handles ]
178
+ elif isinstance (self .debug_handles , Sequence ):
179
+ debug_handles = self .debug_handles
180
+
181
+ for handle in debug_handles :
182
+ node = debug_handle_to_op_node_map .get (handle )
183
+ if node is not None and node .metadata is not None :
184
+ self .stack_traces [node .name ] = node .metadata .get ("stack_trace" )
185
+ if node .metadata :
186
+ self .module_hierarchy [node .name ] = node .metadata .get (
187
+ "nn_module_stack"
188
+ )
189
+ if node .op :
190
+ self .op_types += [node .op ]
191
+
161
192
162
193
@dataclass
163
194
class EventBlock :
@@ -186,11 +217,11 @@ def to_dataframe(self) -> pd.DataFrame:
186
217
"min" : [event .perf_data .min for event in self .events ],
187
218
"max" : [event .perf_data .max for event in self .events ],
188
219
"median" : [event .perf_data .median for event in self .events ],
189
- "op_type " : [event .op_type for event in self .events ],
220
+ "op_types " : [event .op_types for event in self .events ],
190
221
"delegate_debug_identifier" : [
191
222
event .delegate_debug_identifier for event in self .events
192
223
],
193
- "stack_traces" : [event .stack_trace for event in self .events ],
224
+ "stack_traces" : [event .stack_traces for event in self .events ],
194
225
"module_hierarchy" : [event .module_hierarchy for event in self .events ],
195
226
"is_delegated_op" : [event .is_delegated_op for event in self .events ],
196
227
"delegate_backend_name" : [
@@ -290,6 +321,9 @@ def _gen_resolve_debug_handles(
290
321
)
291
322
292
323
324
+ EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
325
+
326
+
293
327
class Inspector :
294
328
"""
295
329
APIs for examining model architecture and performance stats
@@ -301,16 +335,33 @@ def __init__(
301
335
"""
302
336
Create an inspector instance from the provided ETDump/ETRecord
303
337
"""
338
+ etdump , self ._etrecord = gen_etdump_and_etrecord_objects (
339
+ etdump_path = etdump_path , etrecord_path = etrecord_path
340
+ )
341
+
342
+ self .event_blocks = EventBlock ._gen_from_etdump (etdump )
304
343
305
- # Gen op graphs from etrecord
306
- if etrecord_path is not None :
307
- self ._etrecord = parse_etrecord (etrecord_path = etrecord_path )
308
- self ._op_graph_dict : Mapping [
309
- str , OperatorGraphWithStats
310
- ] = gen_graphs_from_etrecord (etrecord = self ._etrecord )
344
+ self ._op_graph_dict : Mapping [
345
+ str , OperatorGraphWithStats
346
+ ] = gen_graphs_from_etrecord (etrecord = self ._etrecord )
311
347
312
- self .event_blocks : List [EventBlock ] = []
313
- # TODO: create event blocks from etdump, and associate events with op graph nodes
348
+ # Use the delegate map from etrecord, associate debug handles with each event
349
+ for event_block in self .event_blocks :
350
+ event_block ._gen_resolve_debug_handles (
351
+ self ._etrecord ._debug_handle_map .get ("forward" ),
352
+ self ._etrecord ._delegate_map .get ("forward" ),
353
+ )
354
+
355
+ # Traverse the edge dialect op graph to create mapping from debug_handle to op node
356
+ debug_handle_to_op_node_map = {}
357
+ create_debug_handle_to_op_node_mapping (
358
+ self ._op_graph_dict [EDGE_DIALECT_GRAPH_KEY ],
359
+ debug_handle_to_op_node_map ,
360
+ )
361
+
362
+ for event_block in self .event_blocks :
363
+ for event in event_block .events :
364
+ event ._associate_with_op_graph_nodes (debug_handle_to_op_node_map )
314
365
315
366
def print_data_tabular (self ) -> None :
316
367
"""
@@ -353,12 +404,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
353
404
# TODO: implement
354
405
pass
355
406
356
- # TODO: add a unittest for this function
357
- def get_exported_program (self , graph : Optional [str ]) -> ExportedProgram :
407
+ def get_exported_program (
408
+ self , graph : Optional [str ] = EDGE_DIALECT_GRAPH_KEY
409
+ ) -> ExportedProgram :
358
410
"""
359
411
Access helper for ETRecord, defaults to returning Edge Dialect Program
412
+
413
+ Args:
414
+ graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
360
415
"""
361
- if not graph :
362
- return self ._etrecord ["edge_dialect_output/forward" ]
363
- else :
364
- return self ._etrecord .get (graph )
416
+ return self ._etrecord .graph_map .get (graph )
0 commit comments