1
1
from __future__ import annotations
2
2
3
3
import logging
4
- from contextlib import nullcontext
5
4
from tempfile import tempdir
6
5
from typing import Any , Dict , List , Optional , Sequence , Tuple
7
6
@@ -107,7 +106,10 @@ def __init__(
107
106
self .engine = None
108
107
self .weight_name_map = weight_name_map
109
108
self .target_platform = Platform .current_platform ()
110
- self .cudagraphs_disabled = False
109
+ # Check if CUDA graph capture is enabled in the parent node
110
+ self .cudagraphs_parent_module = False
111
+ self .cudagraphs_enabled = False
112
+ self .persistent_output_buffer = False
111
113
112
114
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
113
115
self .setup_engine ()
@@ -239,23 +241,31 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
239
241
(i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
240
242
for i in inputs
241
243
]
242
- with nvtx .annotate (f "Forward" , color = "red" ):
244
+ with nvtx .annotate ("Forward" , color = "red" ):
243
245
self ._check_initialized ()
244
- cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode () and not self .cudagraphs_disabled
245
-
246
- need_cudagraphs_record = (
247
- cudagraphs_enabled and not self .cudagraphs_validate_shapes (inputs )
246
+ cudagraphs_enabled = (
247
+ torch_tensorrt .runtime .get_cudagraphs_mode ()
248
+ and not self .cudagraphs_parent_module
248
249
)
250
+ shape_changed = self .validate_input_shapes (inputs )
251
+ # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
252
+ if not self .cudagraphs_enabled and cudagraphs_enabled :
253
+ need_cudagraphs_record = True
254
+ else :
255
+ need_cudagraphs_record = cudagraphs_enabled and shape_changed
256
+ self .cudagraphs_enabled = cudagraphs_enabled
249
257
250
258
if need_cudagraphs_record :
259
+ if self .cudagraph :
260
+ self .cudagraph .reset ()
251
261
self ._input_buffers = [None ] * len (self .input_names )
252
262
self ._output_buffers = [None ] * len (self .output_names )
253
263
254
264
if not cudagraphs_enabled and self .cudagraph :
255
265
self .cudagraph .reset ()
256
266
self .cudagraph = None
257
267
258
- # If in safe mode, check at each iteration for for whether a switch is required
268
+ # If in safe mode, check at each iteration for whether a switch is required
259
269
if (
260
270
torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
261
271
):
@@ -287,7 +297,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
287
297
]
288
298
logger .warning (f"Moved all input Tensors to cuda:{ device_id } " )
289
299
290
- with nvtx .annotate (f "ProcessInputs" , color = "red" ):
300
+ with nvtx .annotate ("ProcessInputs" , color = "red" ):
291
301
assert len (contiguous_inputs ) == len (
292
302
self .input_names
293
303
), f"Wrong number of inputs, expect { len (self .input_names )} get { len (contiguous_inputs )} ."
@@ -349,63 +359,66 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
349
359
This could happen if the input tensor addresses/shapes haven't been configured correctly"
350
360
)
351
361
352
- with nvtx .annotate (f "ProcessOutputs" , color = "red" ):
362
+ with nvtx .annotate ("ProcessOutputs" , color = "red" ):
353
363
# create output tensors
354
364
outputs : List [torch .Tensor ] = []
365
+ if not self .persistent_output_buffer or shape_changed :
366
+ # Create and keep persistent output buffer as long as its shape does not change
367
+ self ._output_buffers = []
368
+ for o , output_name in enumerate (self .output_names ):
369
+ shape = tuple (self .context .get_tensor_shape (output_name ))
370
+ if DYNAMIC_DIM in shape :
371
+ raise ValueError (
372
+ "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
373
+ )
355
374
356
- for o , output_name in enumerate (self .output_names ):
357
- shape = tuple (self .context .get_tensor_shape (output_name ))
358
-
359
- if DYNAMIC_DIM in shape :
360
- raise ValueError (
361
- "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
375
+ output = torch .empty (
376
+ size = shape ,
377
+ dtype = self .output_dtypes [o ].to (torch .dtype ),
378
+ device = torch .cuda .current_device (),
362
379
)
363
-
364
- output = torch .empty (
365
- size = shape ,
366
- dtype = self .output_dtypes [o ].to (torch .dtype ),
367
- device = torch .cuda .current_device (),
368
- )
369
-
370
- outputs .append (output )
371
-
372
- if need_cudagraphs_record :
373
- self ._output_buffers [o ] = outputs [o ].clone ()
374
-
375
- if cudagraphs_enabled :
376
- self .context .set_tensor_address (
377
- output_name , self ._output_buffers [o ].data_ptr ()
378
- )
379
- else :
380
- self .context .set_tensor_address (
381
- output_name , outputs [o ].data_ptr ()
382
- )
383
-
384
- with nvtx .annotate (f"TensorRTRuntime" , color = "red" ):
380
+ if self .persistent_output_buffer :
381
+ self .context .set_tensor_address (
382
+ output_name , output .data_ptr ()
383
+ )
384
+ self ._output_buffers .append (output )
385
+ else :
386
+ outputs .append (output )
387
+ if need_cudagraphs_record :
388
+ self ._output_buffers [o ] = outputs [o ].clone ()
389
+ if cudagraphs_enabled :
390
+ self .context .set_tensor_address (
391
+ output_name , self ._output_buffers [o ].data_ptr ()
392
+ )
393
+ else :
394
+ self .context .set_tensor_address (
395
+ output_name , outputs [o ].data_ptr ()
396
+ )
397
+ with nvtx .annotate ("TensorRTRuntime" , color = "red" ):
385
398
self ._caller_stream = torch .cuda .current_stream ()
386
399
if (
387
400
self ._engine_stream == torch .cuda .default_stream ()
388
401
or self ._engine_stream is None
389
402
):
390
403
self ._engine_stream = torch .cuda .Stream ()
391
404
392
- with nvtx .annotate (f "wait_stream" , color = "green" ):
405
+ with nvtx .annotate ("wait_stream" , color = "green" ):
393
406
self ._engine_stream .wait_stream (self ._caller_stream )
394
407
395
408
with torch .cuda .stream (self ._engine_stream ):
396
409
if cudagraphs_enabled :
397
410
if need_cudagraphs_record :
398
- with nvtx .annotate (f "CUDAGraph" , color = "green" ):
411
+ with nvtx .annotate ("CUDAGraph" , color = "green" ):
399
412
self .cudagraph = torch .cuda .CUDAGraph ()
400
413
401
414
if self .profiling_enabled :
402
415
self .cudagraph .enable_debug_mode ()
403
- with nvtx .annotate (f "torch.cuda.graph" , color = "green" ):
416
+ with nvtx .annotate ("torch.cuda.graph" , color = "green" ):
404
417
with torch .cuda .graph (
405
418
self .cudagraph , stream = self ._engine_stream
406
419
):
407
420
with nvtx .annotate (
408
- f "execute_async_v3" , color = "green"
421
+ "execute_async_v3" , color = "green"
409
422
):
410
423
self .context .execute_async_v3 (
411
424
self ._engine_stream .cuda_stream
@@ -418,17 +431,23 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
418
431
self .cudagraph .debug_dump (
419
432
f"{ tempdir } /{ self .name } _cudagraph.dot"
420
433
)
421
- with nvtx .annotate (f "replay" , color = "green" ):
434
+ with nvtx .annotate ("replay" , color = "green" ):
422
435
self .cudagraph .replay () # type: ignore
423
436
424
437
else :
425
438
self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
426
439
427
440
self ._caller_stream .wait_stream (self ._engine_stream )
428
441
429
- if cudagraphs_enabled :
430
- for idx , o in enumerate (outputs ):
431
- o .copy_ (self ._output_buffers [idx ])
442
+ if self .persistent_output_buffer :
443
+ if len (self ._output_buffers ) == 1 :
444
+ return self ._output_buffers [0 ]
445
+
446
+ return self ._output_buffers
447
+ else :
448
+ if cudagraphs_enabled :
449
+ for idx , o in enumerate (outputs ):
450
+ o .copy_ (self ._output_buffers [idx ])
432
451
433
452
if len (outputs ) == 1 :
434
453
return outputs [0 ]
@@ -467,10 +486,9 @@ def get_layer_info(self) -> str:
467
486
)
468
487
return engine_json
469
488
470
- def cudagraphs_validate_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
489
+ def validate_input_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
471
490
"""
472
- Validates the input shapes of the forward function
473
- versus the version currently active for the
491
+ Checks whether the input shapes of the forward function have changed since the previous call; returns True if they changed, False otherwise
474
492
"""
475
493
# Representation of input shapes to a given model
476
494
# Shapes are concatenated as so:
@@ -480,10 +498,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
480
498
# If the new shape key differs from the existing one,
481
499
# record the new shape key and report the change to the caller
482
500
if new_shape_key != self .shape_key :
483
- logger .debug (f"Resetting Cudagraph on new shape key { new_shape_key } " )
501
+ logger .debug (f"Input shape changed { self . shape_key } -> { new_shape_key } " )
484
502
self .shape_key = new_shape_key
485
- if self .cudagraph :
486
- self .cudagraph .reset ()
487
- return False
503
+ return True
488
504
489
- return True
505
+ return False
0 commit comments