5
5
from tempfile import tempdir
6
6
from typing import Any , Dict , List , Optional , Sequence , Tuple
7
7
8
+ import nvtx
8
9
import tensorrt as trt
9
10
import torch
10
11
import torch_tensorrt
@@ -78,7 +79,6 @@ def __init__(
78
79
self .cudagraph : Optional [torch .cuda .CUDAGraph ] = None
79
80
self ._caller_stream : Optional [torch .cuda .Stream ] = None
80
81
self ._engine_stream : Optional [torch .cuda .Stream ] = None
81
-
82
82
# TODO: Make the below a Dictionary {shape: cudagraph}
83
83
self .shape_key : Optional [str ] = None
84
84
@@ -107,6 +107,7 @@ def __init__(
107
107
self .engine = None
108
108
self .weight_name_map = weight_name_map
109
109
self .target_platform = Platform .current_platform ()
110
+ self .cudagraphs_disabled = False
110
111
111
112
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
112
113
self .setup_engine ()
@@ -238,15 +239,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
238
239
(i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
239
240
for i in inputs
240
241
]
241
-
242
- with (
243
- torch .autograd .profiler .record_function ("PythonTorchTensorRTModule:Forward" )
244
- if self .profiling_enabled
245
- else nullcontext ()
246
- ):
242
+ with nvtx .annotate (f"Forward" , color = "red" ):
247
243
self ._check_initialized ()
244
+ cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode () and not self .cudagraphs_disabled
248
245
249
- cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
250
246
need_cudagraphs_record = (
251
247
cudagraphs_enabled and not self .cudagraphs_validate_shapes (inputs )
252
248
)
@@ -291,13 +287,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
291
287
]
292
288
logger .warning (f"Moved all input Tensors to cuda:{ device_id } " )
293
289
294
- with (
295
- torch .autograd .profiler .record_function (
296
- "PythonTorchTensorRTModule:ProcessInputs"
297
- )
298
- if self .profiling_enabled
299
- else nullcontext ()
300
- ):
290
+ with nvtx .annotate (f"ProcessInputs" , color = "red" ):
301
291
assert len (contiguous_inputs ) == len (
302
292
self .input_names
303
293
), f"Wrong number of inputs, expect { len (self .input_names )} get { len (contiguous_inputs )} ."
@@ -359,13 +349,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
359
349
This could happen if the input tensor addresses/shapes haven't been configured correctly"
360
350
)
361
351
362
- with (
363
- torch .autograd .profiler .record_function (
364
- "PythonTorchTensorRTModule:ProcessOutputs"
365
- )
366
- if self .profiling_enabled
367
- else nullcontext ()
368
- ):
352
+ with nvtx .annotate (f"ProcessOutputs" , color = "red" ):
369
353
# create output tensors
370
354
outputs : List [torch .Tensor ] = []
371
355
@@ -397,37 +381,35 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
397
381
output_name , outputs [o ].data_ptr ()
398
382
)
399
383
400
- with (
401
- torch .autograd .profiler .record_function (
402
- "PythonTorchTensorRTModule:TensorRTRuntime"
403
- )
404
- if self .profiling_enabled
405
- else nullcontext ()
406
- ):
384
+ with nvtx .annotate (f"TensorRTRuntime" , color = "red" ):
407
385
self ._caller_stream = torch .cuda .current_stream ()
408
386
if (
409
387
self ._engine_stream == torch .cuda .default_stream ()
410
388
or self ._engine_stream is None
411
389
):
412
390
self ._engine_stream = torch .cuda .Stream ()
413
391
414
- self ._engine_stream .wait_stream (self ._caller_stream )
392
+ with nvtx .annotate (f"wait_stream" , color = "green" ):
393
+ self ._engine_stream .wait_stream (self ._caller_stream )
415
394
416
395
with torch .cuda .stream (self ._engine_stream ):
417
-
418
396
if cudagraphs_enabled :
419
397
if need_cudagraphs_record :
420
- self .cudagraph = torch .cuda .CUDAGraph ()
398
+ with nvtx .annotate (f"CUDAGraph" , color = "green" ):
399
+ self .cudagraph = torch .cuda .CUDAGraph ()
421
400
422
401
if self .profiling_enabled :
423
402
self .cudagraph .enable_debug_mode ()
424
-
425
- with torch .cuda .graph (
426
- self .cudagraph , stream = self ._engine_stream
427
- ):
428
- self .context .execute_async_v3 (
429
- self ._engine_stream .cuda_stream
430
- )
403
+ with nvtx .annotate (f"torch.cuda.graph" , color = "green" ):
404
+ with torch .cuda .graph (
405
+ self .cudagraph , stream = self ._engine_stream
406
+ ):
407
+ with nvtx .annotate (
408
+ f"execute_async_v3" , color = "green"
409
+ ):
410
+ self .context .execute_async_v3 (
411
+ self ._engine_stream .cuda_stream
412
+ )
431
413
432
414
if self .profiling_enabled :
433
415
import tempfile
@@ -436,8 +418,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
436
418
self .cudagraph .debug_dump (
437
419
f"{ tempdir } /{ self .name } _cudagraph.dot"
438
420
)
439
-
440
- self .cudagraph .replay () # type: ignore
421
+ with nvtx . annotate ( f"replay" , color = "green" ):
422
+ self .cudagraph .replay () # type: ignore
441
423
442
424
else :
443
425
self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
0 commit comments