@@ -312,6 +312,9 @@ class Event:
312
312
_instruction_id : Optional [int ] = None
313
313
314
314
_delegate_metadata_parser : Optional [Callable [[List [str ]], Dict [str , Any ]]] = None
315
+ _delegate_time_scale_converter : Optional [
316
+ Callable [[Union [int , str ], Union [int , float ]], Union [int , float ]]
317
+ ] = None
315
318
316
319
@cached_property
317
320
def delegate_debug_metadatas (self ) -> Union [List [str ], Dict [str , Any ]]:
@@ -391,6 +394,9 @@ def _gen_from_inference_events(
391
394
delegate_metadata_parser : Optional [
392
395
Callable [[List [str ]], Dict [str , Any ]]
393
396
] = None ,
397
+ delegate_time_scale_converter : Optional [
398
+ Callable [[Union [int , float ]], Union [int , float ]]
399
+ ] = None ,
394
400
) -> "Event" :
395
401
"""
396
402
Given an EventSignature and a list of Events with that signature,
@@ -411,6 +417,7 @@ def _gen_from_inference_events(
411
417
name = "" ,
412
418
_instruction_id = signature .instruction_id ,
413
419
_delegate_metadata_parser = delegate_metadata_parser ,
420
+ _delegate_time_scale_converter = delegate_time_scale_converter ,
414
421
)
415
422
416
423
# Populate fields from profile events
@@ -476,14 +483,31 @@ def _populate_profiling_related_fields(
476
483
f"Expected exactly one profile event per InstructionEvent when generating Inspector Event, but got { len (profile_events )} "
477
484
)
478
485
486
+ profile_event = profile_events [0 ]
487
+
479
488
# Scale factor should only be applied to non-delegated ops
480
- scale_factor_updated = 1 if ret_event .is_delegated_op else scale_factor
489
+ if (
490
+ ret_event .is_delegated_op
491
+ and ret_event ._delegate_time_scale_converter is not None
492
+ ):
493
+ scaled_time = ret_event ._delegate_time_scale_converter (
494
+ ret_event .name ,
495
+ profile_event .end_time ,
496
+ # pyre-ignore
497
+ ) - ret_event ._delegate_time_scale_converter (
498
+ ret_event .name , profile_event .start_time
499
+ )
500
+ elif not ret_event .is_delegated_op :
501
+ scaled_time = (
502
+ float (profile_event .end_time - profile_event .start_time )
503
+ / scale_factor
504
+ )
505
+ else :
506
+ scaled_time = float (
507
+ profile_event .end_time - profile_event .start_time
508
+ )
481
509
482
- profile_event = profile_events [0 ]
483
- data .append (
484
- float (profile_event .end_time - profile_event .start_time )
485
- / scale_factor_updated
486
- )
510
+ data .append (scaled_time )
487
511
delegate_debug_metadatas .append (
488
512
profile_event .delegate_debug_metadata
489
513
if profile_event .delegate_debug_metadata
@@ -646,6 +670,9 @@ def _gen_from_etdump(
646
670
delegate_metadata_parser : Optional [
647
671
Callable [[List [str ]], Dict [str , Any ]]
648
672
] = None ,
673
+ delegate_time_scale_converter : Optional [
674
+ Callable [[Union [int , float ]], Union [int , float ]]
675
+ ] = None ,
649
676
) -> List ["EventBlock" ]:
650
677
"""
651
678
Given an etdump, generate a list of EventBlocks corresponding to the
@@ -743,6 +770,7 @@ class GroupedRunInstances:
743
770
scale_factor ,
744
771
output_buffer ,
745
772
delegate_metadata_parser ,
773
+ delegate_time_scale_converter ,
746
774
)
747
775
for signature , instruction_events in run_group .items ()
748
776
]
@@ -875,6 +903,9 @@ def __init__(
875
903
delegate_metadata_parser : Optional [
876
904
Callable [[List [str ]], Dict [str , Any ]]
877
905
] = None ,
906
+ delegate_time_scale_converter : Optional [
907
+ Callable [[Union [int , float ]], Union [int , float ]]
908
+ ] = None ,
878
909
enable_module_hierarchy : bool = False ,
879
910
) -> None :
880
911
r"""
@@ -930,6 +961,7 @@ def __init__(
930
961
self ._target_time_scale ,
931
962
output_buffer ,
932
963
delegate_metadata_parser = delegate_metadata_parser ,
964
+ delegate_time_scale_converter = delegate_time_scale_converter ,
933
965
)
934
966
935
967
# Connect ETRecord to EventBlocks
0 commit comments