@@ -129,28 +129,34 @@ def median(self) -> float:
129
129
@dataclass
130
130
class Event :
131
131
"""
132
- Corresponds to an op instance
132
+ An Event corresponds to an operator instance with perf data retrieved from the runtime and other metadata from `ETRecord`.
133
+
134
+ Args:
135
+ name: Name of the profiling/debugging `Event`.
136
+ perf_data: Performance data associated with the event retrived from the runtime (available attributes: p50, p90, avg, min, max and median).
137
+ op_type: List of op types corresponding to the event.
138
+ delegate_debug_identifier: Supplemental identifier used in combination with instruction id.
139
+ debug_handles: Debug handles in the model graph to which this event is correlated.
140
+ stack_trace: A dictionary mapping the name of each associated op to its stack trace.
141
+ module_hierarchy: A dictionary mapping the name of each associated op to its module hierarchy.
142
+ is_delegated_op: Whether or not the event was delegated.
143
+ delegate_backend_name: Name of the backend this event was delegated to.
144
+ debug_data: Intermediate data collected during runtime.
133
145
"""
134
146
135
147
name : str
136
148
perf_data : PerfData
137
149
op_types : List [str ] = dataclasses .field (default_factory = list )
138
-
139
- # Instruction Id of the original profiling event
140
- instruction_id : Optional [int ] = None
141
-
142
- # Supplemental Identifier used in combination with instruction_identifier
143
150
delegate_debug_identifier : Optional [Union [int , str ]] = None
144
-
145
- # Debug Handles in the model graph to which this event is correlated
146
151
debug_handles : Optional [Union [int , Sequence [int ]]] = None
147
-
148
152
stack_traces : Dict [str , str ] = dataclasses .field (default_factory = dict )
149
153
module_hierarchy : Dict [str , Dict ] = dataclasses .field (default_factory = dict )
150
154
is_delegated_op : Optional [bool ] = None
151
155
delegate_backend_name : Optional [str ] = None
152
156
debug_data : List [torch .Tensor ] = dataclasses .field (default_factory = list )
153
157
158
+ _instruction_id : Optional [int ] = None
159
+
154
160
@staticmethod
155
161
def _gen_from_profile_events (
156
162
signature : ProfileEventSignature ,
@@ -183,9 +189,9 @@ def _gen_from_profile_events(
183
189
return Event (
184
190
name = name ,
185
191
perf_data = perf_data ,
186
- instruction_id = signature .instruction_id ,
187
192
delegate_debug_identifier = delegate_debug_identifier ,
188
193
is_delegated_op = is_delegated_op ,
194
+ _instruction_id = signature .instruction_id ,
189
195
)
190
196
191
197
def _associate_with_op_graph_nodes (
@@ -213,11 +219,14 @@ def _associate_with_op_graph_nodes(
213
219
214
220
@dataclass
215
221
class EventBlock :
216
- """
217
- EventBlock contains a collection of events associated with a particular profiling/debugging block retrieved from the runtime.
218
- Attributes:
219
- name (str): Name of the profiling/debugging block
220
- events (List[Event]): List of events associated with the profiling/debugging block
222
+ r"""
223
+ An `EventBlock` contains a collection of events associated with a particular profiling/debugging block retrieved from the runtime.
224
+ Each `EventBlock` represents a pattern of execution. For example, model initiation and loading lives in a single `EventBlock`.
225
+ If there's a control flow, each branch will be represented by a separate `EventBlock`.
226
+
227
+ Args:
228
+ name: Name of the profiling/debugging block.
229
+ events: List of `Event`\ s associated with the profiling/debugging block.
221
230
"""
222
231
223
232
name : str
@@ -226,7 +235,14 @@ class EventBlock:
226
235
def to_dataframe (self ) -> pd .DataFrame :
227
236
"""
228
237
Converts the EventBlock into a DataFrame with each row being an event instance
238
+
239
+ Args:
240
+ None
241
+
242
+ Returns:
243
+ A Pandas DataFrame containing the data of each Event instance in this EventBlock.
229
244
"""
245
+
230
246
# TODO: push row generation down to Event
231
247
data = {
232
248
"event_block_name" : [self .name ] * len (self .events ),
@@ -320,11 +336,11 @@ def _gen_resolve_debug_handles(
320
336
"""
321
337
for event in self .events :
322
338
# Check if instruction_id is present in the event
323
- if event .instruction_id is None :
339
+ if event ._instruction_id is None :
324
340
continue
325
341
326
342
# Check for the instruction_id in handle map
327
- if (instruction_id := str (event .instruction_id )) not in handle_map :
343
+ if (instruction_id := str (event ._instruction_id )) not in handle_map :
328
344
continue
329
345
330
346
# For non-delegated event, handles are found in handle_map
@@ -339,7 +355,7 @@ def _gen_resolve_debug_handles(
339
355
):
340
356
event .debug_handles = handle_map [instruction_id ]
341
357
log .warning (
342
- f" No delegate mapping found for delegate with instruction id { event .instruction_id } "
358
+ f" No delegate mapping found for delegate with instruction id { event ._instruction_id } "
343
359
)
344
360
continue
345
361
@@ -376,14 +392,18 @@ def __init__(
376
392
etrecord_path : Optional [str ] = None ,
377
393
etdump_scale : int = 1000 ,
378
394
) -> None :
379
- """
380
- Create an inspector instance from the provided ETDump/ETRecord
395
+ r"""
396
+ Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path
397
+ and optional ETRecord path.
381
398
382
399
Args:
383
400
etdump_path: Path to the ETDump file.
384
- etrecord_path: Path to the ETRecord file.
401
+ etrecord_path: Optional path to the ETRecord file.
385
402
etdump_scale: Inverse Scale Factor used to cast the timestamps in ETDump
386
403
defaults to milli (1000ms = 1s).
404
+
405
+ Returns:
406
+ None
387
407
"""
388
408
389
409
self ._etrecord = (
@@ -422,7 +442,13 @@ def __init__(
422
442
423
443
def print_data_tabular (self ) -> None :
424
444
"""
425
- Prints the underlying EventBlocks (essentially all the performance data)
445
+ Displays the underlying EventBlocks in a structured tabular format, with each row representing an Event.
446
+
447
+ Args:
448
+ None
449
+
450
+ Returns:
451
+ None
426
452
"""
427
453
428
454
def style_text_size (val , size = 12 ):
@@ -447,7 +473,17 @@ def style_text_size(val, size=12):
447
473
print (tabulate (filtered_df , headers = "keys" , tablefmt = "fancy_grid" ))
448
474
449
475
# TODO: write unit test
450
- def find_total_for_module (self , module_name : str ):
476
+ def find_total_for_module (self , module_name : str ) -> float :
477
+ """
478
+ Returns the total average compute time of all operators within the specified module.
479
+
480
+ Args:
481
+ module_name: Name of the module to be aggregated against.
482
+
483
+ Returns:
484
+ Sum of the average compute time (in seconds) of all operators within the module with "module_name".
485
+ """
486
+
451
487
total = 0.0
452
488
for block in self .event_blocks :
453
489
for event in block .events :
@@ -481,10 +517,13 @@ def get_exported_program(
481
517
self , graph : Optional [str ] = None
482
518
) -> Optional [ExportedProgram ]:
483
519
"""
484
- Access helper for ETRecord, defaults to returning Edge Dialect Program
520
+ Access helper for ETRecord, defaults to returning the Edge Dialect program.
485
521
486
522
Args:
487
- graph: Name of the graph to access. If None, returns the Edge Dialect Program.
523
+ graph: Optional name of the graph to access. If None, returns the Edge Dialect program.
524
+
525
+ Returns:
526
+ The ExportedProgram object of "graph".
488
527
"""
489
528
if self ._etrecord is None :
490
529
log .warning (
0 commit comments