25
25
import torch
26
26
from executorch .exir import ExportedProgram
27
27
28
- from executorch .sdk .edir .et_schema import OperatorGraphWithStats
29
- from executorch .sdk .etdb ._inspector_utils import gen_graphs_from_etrecord
28
+ from executorch .sdk .edir .et_schema import OperatorGraphWithStats , OperatorNode
29
+ from executorch .sdk .etdb ._inspector_utils import (
30
+ create_debug_handle_to_op_node_mapping ,
31
+ gen_etdump_object ,
32
+ gen_etrecord_object ,
33
+ gen_graphs_from_etrecord ,
34
+ )
30
35
from executorch .sdk .etdump .schema_flatcc import ETDumpFlatCC , ProfileEvent
31
- from executorch . sdk . etrecord import parse_etrecord
36
+
32
37
from tabulate import tabulate
33
38
39
+
40
+ FORWARD = "forward"
41
+ RESERVED_SPECIAL_EVENT_NAMES = [
42
+ "Method::init" ,
43
+ "Program::load_method" ,
44
+ "Method::execute" ,
45
+ ]
46
+ EXCLUDED_COLUMNS_WHEN_PRINTING = [
47
+ "raw" ,
48
+ "delegate_debug_identifier" ,
49
+ "stack_traces" ,
50
+ "module_hierarchy" ,
51
+ "debug_data" ,
52
+ ]
53
+
54
+
34
55
log : logging .Logger = logging .getLogger (__name__ )
35
56
36
57
# Signature of a ProfileEvent
@@ -112,7 +133,7 @@ class Event:
112
133
113
134
name : str
114
135
perf_data : PerfData
115
- op_type : List [str ] = dataclasses .field (default_factory = list )
136
+ op_types : List [str ] = dataclasses .field (default_factory = list )
116
137
117
138
# Instruction Id of the original profiling event
118
139
instruction_id : Optional [int ] = None
@@ -123,7 +144,7 @@ class Event:
123
144
# Debug Handles in the model graph to which this event is correlated
124
145
debug_handles : Optional [Union [int , Sequence [int ]]] = None
125
146
126
- stack_trace : Dict [str , str ] = dataclasses .field (default_factory = dict )
147
+ stack_traces : Dict [str , str ] = dataclasses .field (default_factory = dict )
127
148
module_hierarchy : Dict [str , Dict ] = dataclasses .field (default_factory = dict )
128
149
is_delegated_op : Optional [bool ] = None
129
150
delegate_backend_name : Optional [str ] = None
@@ -138,9 +159,10 @@ def _gen_from_profile_events(
138
159
return an Event object matching the ProfileEventSignature, with perf_data
139
160
populated from the list of ProfileEvents
140
161
"""
141
- delegate_debug_identifier = (
142
- signature .delegate_id or signature .delegate_id_str or None
143
- )
162
+ if signature .delegate_id is not None : # 0 is a valid value
163
+ delegate_debug_identifier = signature .delegate_id
164
+ else :
165
+ delegate_debug_identifier = signature .delegate_id_str or None
144
166
145
167
# Use the delegate identifier as the event name if delegated
146
168
is_delegated_op = delegate_debug_identifier is not None
@@ -158,6 +180,28 @@ def _gen_from_profile_events(
158
180
is_delegated_op = is_delegated_op ,
159
181
)
160
182
183
+ def _associate_with_op_graph_nodes (
184
+ self , debug_handle_to_op_node_map : Dict [int , OperatorNode ]
185
+ ) -> None :
186
+ """
187
+ Helper function to populate the stack_traces, module_hierarchy and op_types attributes
188
+ based on the debug handles of this event
189
+ """
190
+ if (debug_handles := self .debug_handles ) is None :
191
+ return
192
+
193
+ if isinstance (debug_handles , int ):
194
+ debug_handles = [debug_handles ]
195
+
196
+ for handle in debug_handles :
197
+ node = debug_handle_to_op_node_map .get (handle )
198
+ if node is not None and (metadata := node .metadata ) is not None :
199
+ self .stack_traces [node .name ] = metadata .get ("stack_trace" )
200
+ self .module_hierarchy [node .name ] = metadata .get ("nn_module_stack" )
201
+ if node .op :
202
+ # TODO: consider having this as a dict from node.name -> node.op
203
+ self .op_types += [node .op ]
204
+
161
205
162
206
@dataclass
163
207
class EventBlock :
@@ -186,11 +230,11 @@ def to_dataframe(self) -> pd.DataFrame:
186
230
"min" : [event .perf_data .min for event in self .events ],
187
231
"max" : [event .perf_data .max for event in self .events ],
188
232
"median" : [event .perf_data .median for event in self .events ],
189
- "op_type " : [event .op_type for event in self .events ],
233
+ "op_types " : [event .op_types for event in self .events ],
190
234
"delegate_debug_identifier" : [
191
235
event .delegate_debug_identifier for event in self .events
192
236
],
193
- "stack_traces" : [event .stack_trace for event in self .events ],
237
+ "stack_traces" : [event .stack_traces for event in self .events ],
194
238
"module_hierarchy" : [event .module_hierarchy for event in self .events ],
195
239
"is_delegated_op" : [event .is_delegated_op for event in self .events ],
196
240
"delegate_backend_name" : [
@@ -248,10 +292,11 @@ def _gen_from_etdump(etdump: ETDumpFlatCC) -> List["EventBlock"]:
248
292
for index , profile_events in enumerate (profile_run_groups .values ())
249
293
]
250
294
295
+ # TODO: Consider changing the ETRecord deserialization logic to cast ints stored in string format to actual ints
251
296
def _gen_resolve_debug_handles (
252
297
self ,
253
- handle_map : Dict [int , List [int ]],
254
- delegate_map : Optional [Dict [int , DelegateMetadata ]] = None ,
298
+ handle_map : Dict [str , List [int ]],
299
+ delegate_map : Optional [Dict [str , DelegateMetadata ]] = None ,
255
300
):
256
301
"""
257
302
Given mappings from instruction id to debug handles, populate the
@@ -261,10 +306,12 @@ def _gen_resolve_debug_handles(
261
306
to obtain the debug_handle via the delegate map
262
307
"""
263
308
for event in self .events :
309
+ # Check if instruction_id is present in the event
310
+ if event .instruction_id is None :
311
+ continue
312
+
264
313
# Check for the instruction_id in handle map
265
- if (
266
- instruction_id := event .instruction_id
267
- ) is None or instruction_id not in handle_map :
314
+ if (instruction_id := str (event .instruction_id )) not in handle_map :
268
315
continue
269
316
270
317
# For non-delegated event, handles are found in handle_map
@@ -285,14 +332,33 @@ def _gen_resolve_debug_handles(
285
332
286
333
# For delegated events, handles are found via delegateMetadata
287
334
event .delegate_backend_name = delegate_metadata .get ("name" , "" )
288
- event .debug_handles = delegate_metadata .get ("delegate_map" , {}).get (
335
+ delegate_metadata_delegate_map = delegate_metadata .get ("delegate_map" , {})
336
+
337
+ # delegate_debug_id can be either int based or string based, therefore we need to check both
338
+ debug_handles = delegate_metadata_delegate_map .get (
289
339
delegate_debug_id # pyre-ignore
290
340
)
341
+ if debug_handles is not None :
342
+ event .debug_handles = debug_handles
343
+ else :
344
+ event .debug_handles = delegate_metadata_delegate_map .get (
345
+ str (delegate_debug_id ) # pyre-ignore
346
+ )
347
+
348
+
349
+ EDGE_DIALECT_GRAPH_KEY = "edge_dialect_output/forward"
291
350
292
351
293
352
class Inspector :
294
353
"""
295
- APIs for examining model architecture and performance stats
354
+ APIs for examining model architecture and performance stats.
355
+
356
+ Public Attributes:
357
+ event_blocks: List["EventBlock"]. Structured data accessible through Inspector for analysis.
358
+
359
+ Private Attributes:
360
+ _etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object.
361
+ _op_graph_dict: Mapping[str, OperatorGraphWithStats]. Graph objects parsed from etrecord matched with user defined graph names.
296
362
"""
297
363
298
364
def __init__ (
@@ -302,15 +368,34 @@ def __init__(
302
368
Create an inspector instance from the provided ETDump/ETRecord
303
369
"""
304
370
305
- # Gen op graphs from etrecord
306
- if etrecord_path is not None :
307
- self ._etrecord = parse_etrecord (etrecord_path = etrecord_path )
308
- self ._op_graph_dict : Mapping [
309
- str , OperatorGraphWithStats
310
- ] = gen_graphs_from_etrecord (etrecord = self ._etrecord )
371
+ # TODO: etrecord_path can be optional, so we need to support the case when it is not present
372
+ self ._etrecord = gen_etrecord_object (etrecord_path = etrecord_path )
373
+ etdump = gen_etdump_object (etdump_path = etdump_path )
374
+ self .event_blocks = EventBlock ._gen_from_etdump (etdump )
375
+
376
+ self ._op_graph_dict : Mapping [
377
+ str , OperatorGraphWithStats
378
+ ] = gen_graphs_from_etrecord (etrecord = self ._etrecord )
379
+
380
+ # Use the delegate map from etrecord, associate debug handles with each event
381
+ for event_block in self .event_blocks :
382
+ event_block ._gen_resolve_debug_handles (
383
+ self ._etrecord ._debug_handle_map [FORWARD ],
384
+ self ._etrecord ._delegate_map [FORWARD ]
385
+ if self ._etrecord ._delegate_map is not None
386
+ else None ,
387
+ )
311
388
312
- self .event_blocks : List [EventBlock ] = []
313
- # TODO: create event blocks from etdump, and associate events with op graph nodes
389
+ # Traverse the edge dialect op graph to create mapping from debug_handle to op node
390
+ debug_handle_to_op_node_map = {}
391
+ create_debug_handle_to_op_node_mapping (
392
+ self ._op_graph_dict [EDGE_DIALECT_GRAPH_KEY ],
393
+ debug_handle_to_op_node_map ,
394
+ )
395
+
396
+ for event_block in self .event_blocks :
397
+ for event in event_block .events :
398
+ event ._associate_with_op_graph_nodes (debug_handle_to_op_node_map )
314
399
315
400
def print_data_tabular (self ) -> None :
316
401
"""
@@ -322,14 +407,31 @@ def style_text_size(val, size=12):
322
407
323
408
df_list = [event_block .to_dataframe () for event_block in self .event_blocks ]
324
409
combined_df = pd .concat (df_list , ignore_index = True )
325
- # TODO: filter out raw, delegate_debug_identifier, stack_traces and module_hierarchy
410
+ # Filter out some columns for better readability when printing
411
+ filtered_df = combined_df .drop (columns = EXCLUDED_COLUMNS_WHEN_PRINTING )
326
412
try :
327
413
from IPython .display import display
328
414
329
- styled_df = combined_df .style .applymap (style_text_size )
415
+ styled_df = filtered_df .style .applymap (style_text_size )
330
416
display (styled_df )
331
417
except :
332
- print (tabulate (combined_df , headers = "keys" , tablefmt = "fancy_grid" ))
418
+ # TODO: figure out how to trigger this path in a Python shell
419
+ print (tabulate (filtered_df , headers = "keys" , tablefmt = "fancy_grid" ))
420
+
421
+ # TODO: write unit test
422
+ def find_total_for_module (self , module_name : str ):
423
+ total = 0.0
424
+ for block in self .event_blocks :
425
+ for event in block .events :
426
+ module_hierarchy = event .module_hierarchy .values ()
427
+ for hierarchy in module_hierarchy :
428
+ if not hierarchy :
429
+ continue
430
+ found = any (module_name in key for key in hierarchy .keys ())
431
+ if found :
432
+ total += event .perf_data .avg
433
+ break
434
+ return total
333
435
334
436
def get_event_blocks (self ) -> List [EventBlock ]:
335
437
"""
@@ -353,12 +455,13 @@ def write_tensorboard_artifact(self, path: str) -> None:
353
455
# TODO: implement
354
456
pass
355
457
356
- # TODO: add a unittest for this function
357
- def get_exported_program (self , graph : Optional [str ]) -> ExportedProgram :
458
+ def get_exported_program (
459
+ self , graph : Optional [str ] = EDGE_DIALECT_GRAPH_KEY
460
+ ) -> ExportedProgram :
358
461
"""
359
462
Access helper for ETRecord, defaults to returning Edge Dialect Program
463
+
464
+ Args:
465
+ graph: Name of the graph to access, defaults to "edge_dialect_output/forward"
360
466
"""
361
- if not graph :
362
- return self ._etrecord ["edge_dialect_output/forward" ]
363
- else :
364
- return self ._etrecord .get (graph )
467
+ return self ._etrecord .graph_map .get (graph )
0 commit comments