@@ -30,7 +30,7 @@ def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
30
30
# Indicates whether pre-allocated output was enabled in the previous execute_engine
31
31
self .old_pre_allocated_outputs = new_pre_allocated_output
32
32
33
- def validate_states (
33
+ def set_runtime_states (
34
34
self ,
35
35
new_cudagraphs : bool ,
36
36
new_pre_allocated_output : bool ,
@@ -144,8 +144,11 @@ def __init__(
144
144
self .engine = None
145
145
self .weight_name_map = weight_name_map
146
146
self .target_platform = Platform .current_platform ()
147
- # Previous cuda graphs state
148
- self .prev_cudagraphs_enabled = False
147
+ self .runtime_states = TorchTRTRuntimeStates (
148
+ torch_tensorrt .runtime .get_cudagraphs_mode (), False
149
+ )
150
+ self .pre_allocated_outputs : List [torch .Tensor ] = []
151
+ self .use_pre_allocated_outputs = False
149
152
150
153
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
151
154
self .setup_engine ()
@@ -352,14 +355,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
352
355
self ._check_initialized ()
353
356
354
357
cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
355
- shape_changed = self .cudagraphs_validate_shapes (inputs )
356
- # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
357
- need_cudagraphs_record = cudagraphs_enabled and (
358
- (not self .prev_cudagraphs_enabled ) or (not shape_changed )
358
+ shape_changed = self .validate_input_shapes (inputs )
359
+ need_cudagraphs_record , can_use_pre_allocated_outputs = (
360
+ self .runtime_states .set_runtime_states (
361
+ cudagraphs_enabled , self .use_pre_allocated_outputs , shape_changed
362
+ )
359
363
)
360
- self .prev_cudagraphs_enabled = cudagraphs_enabled
361
364
362
365
if need_cudagraphs_record :
366
+ if self .cudagraph :
367
+ self .cudagraph .reset ()
363
368
self ._input_buffers = [None ] * len (self .input_names )
364
369
self ._output_buffers = [None ] * len (self .output_names )
365
370
@@ -423,8 +428,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
423
428
This could happen if the input tensor addresses/shapes haven't been configured correctly"
424
429
)
425
430
426
- with nvtx .annotate ("ProcessOutputs:1" , color = "red" ):
427
- if not self .use_pre_allocated_outputs or shape_changed :
431
+ with (
432
+ torch .autograd .profiler .record_function (
433
+ "PythonTorchTensorRTModule:ProcessOutputs"
434
+ )
435
+ if self .profiling_enabled
436
+ else nullcontext ()
437
+ ):
438
+ if can_use_pre_allocated_outputs :
439
+ outputs = self .pre_allocated_outputs
440
+ else :
428
441
self .output_shapes = [
429
442
tuple (self .context .get_tensor_shape (output_name ))
430
443
for output_name in self .output_names
@@ -434,12 +447,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
434
447
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
435
448
)
436
449
outputs = self .create_output_tensors ()
437
- else :
438
- outputs = self .pre_allocated_outputs
439
450
440
451
for o , output_name in enumerate (self .output_names ):
452
+
441
453
if need_cudagraphs_record :
442
454
self ._output_buffers [o ] = outputs [o ].clone ()
455
+
443
456
if cudagraphs_enabled :
444
457
self .context .set_tensor_address (
445
458
output_name , self ._output_buffers [o ].data_ptr ()
@@ -449,7 +462,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
449
462
output_name , outputs [o ].data_ptr ()
450
463
)
451
464
452
- with nvtx .annotate ("TensorRTRuntime" , color = "red" ):
465
+ with (
466
+ torch .autograd .profiler .record_function (
467
+ "PythonTorchTensorRTModule:TensorRTRuntime"
468
+ )
469
+ if self .profiling_enabled
470
+ else nullcontext ()
471
+ ):
453
472
self ._caller_stream = torch .cuda .current_stream ()
454
473
if (
455
474
self ._engine_stream == torch .cuda .default_stream ()
@@ -490,6 +509,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
490
509
491
510
self ._caller_stream .wait_stream (self ._engine_stream )
492
511
512
+ if self .use_pre_allocated_outputs :
513
+ self .pre_allocated_outputs = self .create_output_tensors ()
514
+
493
515
if cudagraphs_enabled :
494
516
for idx , o in enumerate (outputs ):
495
517
o .copy_ (self ._output_buffers [idx ])
@@ -531,10 +553,9 @@ def get_layer_info(self) -> str:
531
553
)
532
554
return engine_json
533
555
534
- def cudagraphs_validate_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
556
+ def validate_input_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
535
557
"""
536
- Validates the input shapes of the forward function
537
- versus the version currently active for the
558
+ Checks whether the input shapes of the forward function have changed
538
559
"""
539
560
# Representation of input shapes to a given model
540
561
# Shapes are concatenated as so:
@@ -544,10 +565,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
544
565
# If the new shape key differs from the existing one,
545
566
# invalidate the old shape key and remove the CUDAGraph
546
567
if new_shape_key != self .shape_key :
547
- logger .debug (f"Resetting Cudagraph on new shape key { new_shape_key } " )
568
+ logger .debug (f"Input shape changed { self . shape_key } -> { new_shape_key } " )
548
569
self .shape_key = new_shape_key
549
- if self .cudagraph :
550
- self .cudagraph .reset ()
551
- return False
570
+ return True
552
571
553
- return True
572
+ return False
0 commit comments