25
25
import torch
26
26
from executorch .exir import ExportedProgram
27
27
28
- from executorch .sdk .edir .et_schema import OperatorGraphWithStats
29
- from executorch .sdk .etdb ._inspector_utils import gen_graphs_from_etrecord
28
+ from executorch .sdk .edir .et_schema import OperatorGraphWithStats , OperatorNode
29
+ from executorch .sdk .etdb ._inspector_utils import (
30
+ create_debug_handle_to_op_node_mapping ,
31
+ gen_etdump_object ,
32
+ gen_etrecord_object ,
33
+ gen_graphs_from_etrecord ,
34
+ )
30
35
from executorch .sdk .etdump .schema_flatcc import ETDumpFlatCC , ProfileEvent
31
- from executorch . sdk . etrecord import parse_etrecord
36
+
32
37
from tabulate import tabulate
33
38
39
+
40
+ FORWARD = "forward"
41
+ RESERVED_SPECIAL_EVENT_NAMES = [
42
+ "Method::init" ,
43
+ "Program::load_method" ,
44
+ "Method::execute" ,
45
+ ]
46
+
47
+
34
48
log : logging .Logger = logging .getLogger (__name__ )
35
49
36
50
# Signature of a ProfileEvent
@@ -112,7 +126,7 @@ class Event:
112
126
113
127
name : str
114
128
perf_data : PerfData
115
- op_type : List [str ] = dataclasses .field (default_factory = list )
129
+ op_types : List [str ] = dataclasses .field (default_factory = list )
116
130
117
131
# Instruction Id of the original profiling event
118
132
instruction_id : Optional [int ] = None
@@ -123,7 +137,7 @@ class Event:
123
137
# Debug Handles in the model graph to which this event is correlated
124
138
debug_handles : Optional [Union [int , Sequence [int ]]] = None
125
139
126
- stack_trace : Dict [str , str ] = dataclasses .field (default_factory = dict )
140
+ stack_traces : Dict [str , str ] = dataclasses .field (default_factory = dict )
127
141
module_hierarchy : Dict [str , Dict ] = dataclasses .field (default_factory = dict )
128
142
is_delegated_op : Optional [bool ] = None
129
143
delegate_backend_name : Optional [str ] = None
@@ -138,9 +152,10 @@ def _gen_from_profile_events(
138
152
return an Event object matching the ProfileEventSignature, with perf_data
139
153
populated from the list of ProfileEvents
140
154
"""
141
- delegate_debug_identifier = (
142
- signature .delegate_id or signature .delegate_id_str or None
143
- )
155
+ if signature .delegate_id is not None : # 0 is a valid value
156
+ delegate_debug_identifier = signature .delegate_id
157
+ else :
158
+ delegate_debug_identifier = signature .delegate_id_str or None
144
159
145
160
# Use the delegate identifier as the event name if delegated
146
161
is_delegated_op = delegate_debug_identifier is not None
@@ -158,6 +173,28 @@ def _gen_from_profile_events(
158
173
is_delegated_op = is_delegated_op ,
159
174
)
160
175
176
+ def _associate_with_op_graph_nodes (
177
+ self , debug_handle_to_op_node_map : Dict [int , OperatorNode ]
178
+ ) -> None :
179
+ """
180
+ Helper function to populate the stack_traces, module_hierarchy and op_types attributes
181
+ based on the debug handles of this event
182
+ """
183
+ if (debug_handles := self .debug_handles ) is None :
184
+ return
185
+
186
+ if isinstance (debug_handles , int ):
187
+ debug_handles = [debug_handles ]
188
+
189
+ for handle in debug_handles :
190
+ node = debug_handle_to_op_node_map .get (handle )
191
+ if node is not None and (metadata := node .metadata ) is not None :
192
+ self .stack_traces [node .name ] = metadata .get ("stack_trace" )
193
+ self .module_hierarchy [node .name ] = metadata .get ("nn_module_stack" )
194
+ if node .op :
195
+ # TODO: consider having this as a dict from node.name -> node.op
196
+ self .op_types += [node .op ]
197
+
161
198
162
199
@dataclass
163
200
class EventBlock :
@@ -186,11 +223,11 @@ def to_dataframe(self) -> pd.DataFrame:
186
223
"min" : [event .perf_data .min for event in self .events ],
187
224
"max" : [event .perf_data .max for event in self .events ],
188
225
"median" : [event .perf_data .median for event in self .events ],
189
- "op_type " : [event .op_type for event in self .events ],
226
+ "op_types " : [event .op_types for event in self .events ],
190
227
"delegate_debug_identifier" : [
191
228
event .delegate_debug_identifier for event in self .events
192
229
],
193
- "stack_traces" : [event .stack_trace for event in self .events ],
230
+ "stack_traces" : [event .stack_traces for event in self .events ],
194
231
"module_hierarchy" : [event .module_hierarchy for event in self .events ],
195
232
"is_delegated_op" : [event .is_delegated_op for event in self .events ],
196
233
"delegate_backend_name" : [
@@ -250,8 +287,8 @@ def _gen_from_etdump(etdump: ETDumpFlatCC) -> List["EventBlock"]:
250
287
251
288
def _gen_resolve_debug_handles (
252
289
self ,
253
- handle_map : Dict [int , List [int ]],
254
- delegate_map : Optional [Dict [int , DelegateMetadata ]] = None ,
290
+ handle_map : Dict [str , List [int ]],
291
+ delegate_map : Optional [Dict [str , DelegateMetadata ]] = None ,
255
292
):
256
293
"""
257
294
Given mappings from instruction id to debug handles, populate the
@@ -263,7 +300,7 @@ def _gen_resolve_debug_handles(
263
300
for event in self .events :
264
301
# Check for the instruction_id in handle map
265
302
if (
266
- instruction_id := event .instruction_id
303
+ instruction_id := str ( event .instruction_id )
267
304
) is None or instruction_id not in handle_map :
268
305
continue
269
306
@@ -285,14 +322,31 @@ def _gen_resolve_debug_handles(
285
322
286
323
# For delegated events, handles are found via delegateMetadata
287
324
event .delegate_backend_name = delegate_metadata .get ("name" , "" )
288
- event .debug_handles = delegate_metadata .get ("delegate_map" , {}).get (
325
+ delegate_metadata_delegate_map = delegate_metadata .get ("delegate_map" , {})
326
+ debug_handles = delegate_metadata_delegate_map .get (
289
327
delegate_debug_id # pyre-ignore
290
328
)
329
+ if debug_handles is not None :
330
+ event .debug_handles = debug_handles
331
+ else :
332
+ event .debug_handles = delegate_metadata_delegate_map .get (
333
+ str (delegate_debug_id ) # pyre-ignore
334
+ )
335
+
336
+
337
+ EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
291
338
292
339
293
340
class Inspector :
294
341
"""
295
- APIs for examining model architecture and performance stats
342
+ APIs for examining model architecture and performance stats.
343
+
344
+ Public Attributes:
345
+ event_blocks: List["EventBlocks"]. Structured data accessible through Inspector for analysis.
346
+
347
+ Private Attributes:
348
+ _etrecord: ETRecord. File under etrecord_path deserialized into an object.
349
+ _op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
296
350
"""
297
351
298
352
def __init__ (
@@ -302,15 +356,34 @@ def __init__(
302
356
Create an inspector instance from the provided ETDump/ETRecord
303
357
"""
304
358
305
- # Gen op graphs from etrecord
306
- if etrecord_path is not None :
307
- self ._etrecord = parse_etrecord (etrecord_path = etrecord_path )
308
- self ._op_graph_dict : Mapping [
309
- str , OperatorGraphWithStats
310
- ] = gen_graphs_from_etrecord (etrecord = self ._etrecord )
359
+ self ._etrecord = gen_etrecord_object (etrecord_path = etrecord_path )
360
+ etdump = gen_etdump_object (etdump_path = etdump_path )
361
+
362
+ self .event_blocks = EventBlock ._gen_from_etdump (etdump )
363
+
364
+ self ._op_graph_dict : Mapping [
365
+ str , OperatorGraphWithStats
366
+ ] = gen_graphs_from_etrecord (etrecord = self ._etrecord )
311
367
312
- self .event_blocks : List [EventBlock ] = []
313
- # TODO: create event blocks from etdump, and associate events with op graph nodes
368
+ # Use the delegate map from etrecord, associate debug handles with each event
369
+ for event_block in self .event_blocks :
370
+ event_block ._gen_resolve_debug_handles (
371
+ self ._etrecord ._debug_handle_map [FORWARD ],
372
+ self ._etrecord ._delegate_map [FORWARD ]
373
+ if self ._etrecord ._delegate_map is not None
374
+ else None ,
375
+ )
376
+
377
+ # Traverse the edge dialect op graph to create mapping from debug_handle to op node
378
+ debug_handle_to_op_node_map = {}
379
+ create_debug_handle_to_op_node_mapping (
380
+ self ._op_graph_dict [EDGE_DIALECT_GRAPH_KEY ],
381
+ debug_handle_to_op_node_map ,
382
+ )
383
+
384
+ for event_block in self .event_blocks :
385
+ for event in event_block .events :
386
+ event ._associate_with_op_graph_nodes (debug_handle_to_op_node_map )
314
387
315
388
def print_data_tabular (self ) -> None :
316
389
"""
@@ -322,14 +395,38 @@ def style_text_size(val, size=12):
322
395
323
396
df_list = [event_block .to_dataframe () for event_block in self .event_blocks ]
324
397
combined_df = pd .concat (df_list , ignore_index = True )
325
- # TODO: filter out raw, delegate_debug_identifier, stack_traces and module_hierarchy
398
+ # Filter out raw, delegate_debug_identifier, stack_traces, module_hierarchy and debug_data for better readability
399
+ columns_to_drop = [
400
+ "raw" ,
401
+ "delegate_debug_identifier" ,
402
+ "stack_traces" ,
403
+ "module_hierarchy" ,
404
+ "debug_data" ,
405
+ ]
406
+ # Drop the specified columns
407
+ filtered_df = combined_df .drop (columns = columns_to_drop )
326
408
try :
327
409
from IPython .display import display
328
410
329
- styled_df = combined_df .style .applymap (style_text_size )
411
+ styled_df = filtered_df .style .applymap (style_text_size )
330
412
display (styled_df )
331
413
except :
332
- print (tabulate (combined_df , headers = "keys" , tablefmt = "fancy_grid" ))
414
+ print (tabulate (filtered_df , headers = "keys" , tablefmt = "fancy_grid" ))
415
+
416
+ # TODO: write unit test
417
+ def find_total_for_module (self , module_name : str ):
418
+ total = 0.0
419
+ for block in self .event_blocks :
420
+ for event in block .events :
421
+ module_hierarchy = event .module_hierarchy .values ()
422
+ for hierarchy in module_hierarchy :
423
+ if not hierarchy :
424
+ continue
425
+ found = any (module_name in key for key in hierarchy .keys ())
426
+ if found :
427
+ total += event .perf_data .avg
428
+ break
429
+ return total
333
430
334
431
def get_event_blocks (self ) -> List [EventBlock ]:
335
432
"""
@@ -353,12 +450,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
353
450
# TODO: implement
354
451
pass
355
452
356
- # TODO: add a unittest for this function
357
- def get_exported_program (self , graph : Optional [str ]) -> ExportedProgram :
453
+ def get_exported_program (
454
+ self , graph : Optional [str ] = EDGE_DIALECT_GRAPH_KEY
455
+ ) -> ExportedProgram :
358
456
"""
359
457
Access helper for ETRecord, defaults to returning Edge Dialect Program
458
+
459
+ Args:
460
+ graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
360
461
"""
361
- if not graph :
362
- return self ._etrecord ["edge_dialect_output/forward" ]
363
- else :
364
- return self ._etrecord .get (graph )
462
+ return self ._etrecord .graph_map .get (graph )
0 commit comments