from __future__ import annotations

import logging
+ from contextlib import nullcontext
from tempfile import tempdir
from typing import Any, Dict, List, Optional, Sequence, Tuple

- import nvtx
import tensorrt as trt
import torch
import torch_tensorrt
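Side note on the profiling change: every `nvtx.annotate` range in this file is replaced by `torch.autograd.profiler.record_function` gated on `self.profiling_enabled`, with `contextlib.nullcontext` as the no-op fallback. A minimal runnable sketch of that pattern (the label and body are illustrative, not taken from this file):

```python
from contextlib import nullcontext

import torch


def timed_step(profiling_enabled: bool) -> None:
    # record_function emits a named range into torch profiler traces;
    # nullcontext() keeps the non-profiling hot path free of that overhead.
    with (
        torch.autograd.profiler.record_function("Example:Forward")
        if profiling_enabled
        else nullcontext()
    ):
        pass  # illustrative placeholder for the real work
```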
@@ -107,10 +107,8 @@ def __init__(
        self.weight_name_map = weight_name_map
        self.target_platform = Platform.current_platform()
        # Check if CUDA graph capture is enabled in the parent node
-        self.cudagraphs_parent_module = False
+        self.cudagraphs_enabled_parent_module = False
        self.cudagraphs_enabled = False
-        self.pre_allocated_outputs: List[torch.Tensor] = []
-        self.use_pre_allocated_outputs = False

        if self.serialized_engine is not None and not self.settings.lazy_engine_init:
            self.setup_engine()
@@ -236,42 +234,34 @@ def __del__(self) -> None:
        if self.cudagraph:
            self.cudagraph.reset()

-    def create_output_tensors(self) -> List[torch.Tensor]:
-        # create output tensors
-        outputs: List[torch.Tensor] = []
-
-        for o, _ in enumerate(self.output_names):
-            output = torch.empty(
-                size=self.output_shapes[o],
-                dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
-            )
-            outputs.append(output)
-        return outputs
-
    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
        # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
        contiguous_inputs: List[torch.Tensor] = [
            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
            for i in inputs
        ]
-        with nvtx.annotate("Forward", color="red"):
+
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()
+
            cudagraphs_enabled = (
                torch_tensorrt.runtime.get_cudagraphs_mode()
-                and not self.cudagraphs_parent_module
+                and not self.cudagraphs_enabled_parent_module
            )
-            shape_changed = self.validate_input_shapes(inputs)
-            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+            # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
            if not self.cudagraphs_enabled and cudagraphs_enabled:
                need_cudagraphs_record = True
            else:
-                need_cudagraphs_record = cudagraphs_enabled and shape_changed
+                need_cudagraphs_record = (
+                    cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
+                )
            self.cudagraphs_enabled = cudagraphs_enabled

            if need_cudagraphs_record:
-                if self.cudagraph:
-                    self.cudagraph.reset()
                self._input_buffers = [None] * len(self.input_names)
                self._output_buffers = [None] * len(self.output_names)
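For readers tracing the control flow: after this hunk, a CUDA graph is (re)recorded either because cudagraphs mode was just switched on, or because it is already on and the cached shape key no longer validates. A condensed sketch of that decision, with hypothetical plain arguments standing in for the module state:

```python
def needs_cudagraphs_record(
    was_enabled: bool, now_enabled: bool, shapes_valid: bool
) -> bool:
    # Newly enabled: always record, regardless of shape changes.
    if not was_enabled and now_enabled:
        return True
    # Still enabled: record only when the cached shape key failed validation.
    return now_enabled and not shapes_valid
```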
@@ -311,7 +301,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                ]
                logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with nvtx.annotate("ProcessInputs", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(contiguous_inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
@@ -364,32 +360,44 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                    self.context.set_tensor_address(
                        input_name, contiguous_inputs[i].data_ptr()
                    )
-                if shape_changed:
-                    # Check if input shapes can be inferred.
-                    uninferred_input_names = self.context.infer_shapes()
-                    if uninferred_input_names:
-                        logger.warning(
-                            f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                            This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                        )

-            with nvtx.annotate("ProcessOutputs:1", color="red"):
-                if not self.use_pre_allocated_outputs or shape_changed:
-                    self.output_shapes = [
-                        tuple(self.context.get_tensor_shape(output_name))
-                        for output_name in self.output_names
-                    ]
-                    if DYNAMIC_DIM in self.output_shapes:
+                # Check if input shapes can be inferred.
+                uninferred_input_names = self.context.infer_shapes()
+                if uninferred_input_names:
+                    logger.warning(
+                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                        This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                    )
+
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                # create output tensors
+                outputs: List[torch.Tensor] = []
+
+                for o, output_name in enumerate(self.output_names):
+                    shape = tuple(self.context.get_tensor_shape(output_name))
+
+                    if DYNAMIC_DIM in shape:
                        raise ValueError(
                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                        )
-                    outputs = self.create_output_tensors()
-                else:
-                    outputs = self.pre_allocated_outputs

-                for o, output_name in enumerate(self.output_names):
+                    output = torch.empty(
+                        size=shape,
+                        dtype=self.output_dtypes[o],
+                        device=torch.cuda.current_device(),
+                    )
+
+                    outputs.append(output)
+
                    if need_cudagraphs_record:
                        self._output_buffers[o] = outputs[o].clone()
+
                    if cudagraphs_enabled:
                        self.context.set_tensor_address(
                            output_name, self._output_buffers[o].data_ptr()
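The output loop above allocates each output from the shape TensorRT reports and binds its address into the execution context. A self-contained sketch of that binding step (`bind_outputs` is a hypothetical helper; `context` is a `tensorrt.IExecutionContext`, and the dtype list follows this module's convention):

```python
from typing import List, Sequence

import torch


def bind_outputs(context, output_names: Sequence[str], output_dtypes) -> List[torch.Tensor]:
    outputs = []
    for o, name in enumerate(output_names):
        # The shape is only fully resolved once the input shapes are set.
        shape = tuple(context.get_tensor_shape(name))
        out = torch.empty(size=shape, dtype=output_dtypes[o], device="cuda")
        # TensorRT writes results to whatever address is bound for this tensor.
        context.set_tensor_address(name, out.data_ptr())
        outputs.append(out)
    return outputs
```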
@@ -399,35 +407,37 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                            output_name, outputs[o].data_ptr()
                        )

-            with nvtx.annotate("TensorRTRuntime", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self._caller_stream = torch.cuda.current_stream()
                if (
                    self._engine_stream == torch.cuda.default_stream()
                    or self._engine_stream is None
                ):
                    self._engine_stream = torch.cuda.Stream()

-                with nvtx.annotate("wait_stream", color="green"):
-                    self._engine_stream.wait_stream(self._caller_stream)
+                self._engine_stream.wait_stream(self._caller_stream)

                with torch.cuda.stream(self._engine_stream):
+
                    if cudagraphs_enabled:
                        if need_cudagraphs_record:
-                            with nvtx.annotate("CUDAGraph", color="green"):
-                                self.cudagraph = torch.cuda.CUDAGraph()
+                            self.cudagraph = torch.cuda.CUDAGraph()

                            if self.profiling_enabled:
                                self.cudagraph.enable_debug_mode()
-                            with nvtx.annotate("torch.cuda.graph", color="green"):
-                                with torch.cuda.graph(
-                                    self.cudagraph, stream=self._engine_stream
-                                ):
-                                    with nvtx.annotate(
-                                        "execute_async_v3", color="green"
-                                    ):
-                                        self.context.execute_async_v3(
-                                            self._engine_stream.cuda_stream
-                                        )
+
+                            with torch.cuda.graph(
+                                self.cudagraph, stream=self._engine_stream
+                            ):
+                                self.context.execute_async_v3(
+                                    self._engine_stream.cuda_stream
+                                )

                        if self.profiling_enabled:
                            import tempfile
@@ -436,18 +446,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                self.cudagraph.debug_dump(
                                    f"{tempdir}/{self.name}_cudagraph.dot"
                                )
-                        with nvtx.annotate("replay", color="green"):
-                            self.cudagraph.replay()  # type: ignore
+
+                        self.cudagraph.replay()  # type: ignore

                    else:
                        self.context.execute_async_v3(self._engine_stream.cuda_stream)

                self._caller_stream.wait_stream(self._engine_stream)

-            if self.use_pre_allocated_outputs:
-                with nvtx.annotate("ProcessOutputs:2", color="red"):
-                    self.pre_allocated_outputs = self.create_output_tensors()
-
            if cudagraphs_enabled:
                for idx, o in enumerate(outputs):
                    o.copy_(self._output_buffers[idx])
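For context on the capture path retained above: `torch.cuda.CUDAGraph` records the kernels launched inside `torch.cuda.graph(...)` and later replays them against the same buffer addresses, which is why inputs and outputs are staged through `_input_buffers`/`_output_buffers` and copied out after replay. A generic capture/replay sketch, independent of TensorRT (`work` is any callable that launches CUDA kernels on tensors with stable addresses):

```python
import torch


def capture_and_replay(work) -> None:
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        work()  # warm-up run off the default stream, as capture requires
    torch.cuda.current_stream().wait_stream(side_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=side_stream):
        work()  # recorded into the graph, not executed eagerly
    graph.replay()  # re-launches the recorded kernels
```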
@@ -489,9 +495,10 @@ def get_layer_info(self) -> str:
        )
        return engine_json

-    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
        """
-        Validates the input shapes of the forward function has changed
+        Validates the input shapes of the forward function
+        versus the version currently active for the CUDAGraph
        """
        # Representation of input shapes to a given model
        # Shapes are concatenated as so:
@@ -501,8 +508,10 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
        # If the new shape key differs from the existing one,
        # invalidate the old shape key and remove the CUDAGraph
        if new_shape_key != self.shape_key:
-            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
+            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
            self.shape_key = new_shape_key
-            return True
+            if self.cudagraph:
+                self.cudagraph.reset()
+            return False

-        return False
+        return True
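Net effect of this last hunk: the renamed `cudagraphs_validate_shapes` now resets the graph itself on a mismatch and inverts the return convention, returning `True` when the cached shapes are still valid. The shape key it compares against is a string concatenation of the input shapes; a sketch of one plausible construction, assumed here since the hunk does not show it:

```python
from typing import Sequence

import torch


def make_shape_key(inputs: Sequence[torch.Tensor]) -> str:
    # e.g. two inputs of shapes (1, 3, 224, 224) and (1,) yield
    # "(1,3,224,224)(1,)" -- any shape change produces a new key.
    return "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
```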