
Commit 92ab985

feat: Wrapped module to record/replay cudagraph in sub modules
1 parent 6d40ff1

4 files changed, +352 -41 lines changed


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 0 deletions
@@ -34,6 +34,9 @@
     post_lowering,
     pre_export_lowering,
 )
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 from torch_tensorrt.dynamo.utils import (
     get_flat_args_with_check,
     parse_graph_io,
@@ -516,6 +519,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     dryrun_stats_display(dryrun_tracker, settings.dryrun)
 
+    if len(trt_modules) > 1:
+        # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
+        partitioned_module = WrapperTorchTensorRTModule(
+            partitioned_module, dryrun_tracker.output_dtypes
+        )
+
     return partitioned_module
 
 
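Note: the new runtime wrapper itself (py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py, one of the four changed files) is not shown in this view. As a rough, non-authoritative sketch of the idea it implements: a wrapper nn.Module that records the whole partitioned module (TRT engines plus any Torch fallback subgraphs) into one CUDA graph on first use and replays it afterwards. Only the constructor arguments come from this diff; the class and attribute names below are illustrative assumptions, and the sketch omits the warm-up iterations and shape checks a real implementation needs.

    from typing import Any, List, Optional, Sequence

    import torch


    class WholeModuleCudagraphSketch(torch.nn.Module):  # hypothetical stand-in
        def __init__(
            self, compiled_module: torch.nn.Module, output_dtypes: Sequence[torch.dtype]
        ) -> None:
            super().__init__()
            self.compiled_module = compiled_module
            self.output_dtypes = output_dtypes
            self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
            self._inputs: List[torch.Tensor] = []
            self._outputs: Any = None

        def forward(self, *inputs: torch.Tensor) -> Any:
            if self.cudagraph is None:
                # Record once: static input buffers must stay alive, because
                # replay reuses the exact memory addresses that were captured.
                self._inputs = [i.detach().clone() for i in inputs]
                self.cudagraph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.cudagraph):
                    self._outputs = self.compiled_module(*self._inputs)
            else:
                # Replay path: refresh the captured buffers with new data.
                for static_in, new_in in zip(self._inputs, inputs):
                    static_in.copy_(new_in)
            # Capture does not execute kernels, so replay on every call,
            # including the first, to materialize the outputs.
            self.cudagraph.replay()
            return self._outputs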

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 23 additions & 41 deletions
@@ -5,6 +5,7 @@
 from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import nvtx
 import tensorrt as trt
 import torch
 import torch_tensorrt
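Note: nvtx here is NVIDIA's Python NVTX bindings (pip install nvtx). Unlike torch.autograd.profiler.record_function, which this commit removes below, nvtx.annotate emits NVTX ranges that show up directly in Nsight Systems timelines; with no profiler attached the ranges are effectively no-ops. A minimal usage sketch:

    import nvtx

    # Usable as a context manager or a decorator; message and color label
    # the range in the Nsight Systems timeline.
    with nvtx.annotate("ProcessInputs", color="red"):
        pass  # region of interest


    @nvtx.annotate("forward", color="green")
    def forward_step() -> None:
        pass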
@@ -78,7 +79,6 @@ def __init__(
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
-
         # TODO: Make the below a Dictionary {shape: cudagraph}
         self.shape_key: Optional[str] = None
 
@@ -107,6 +107,7 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
+        self.cudagraphs_disabled = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
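Note: the new cudagraphs_disabled flag gives an enclosing module a per-engine opt-out: in forward() below, CUDA graph handling is skipped whenever the flag is set, even if the global cudagraphs mode is on. This is presumably how the wrapper avoids nested capture (CUDA does not allow starting a capture while another capture is in progress) when it records the whole module as one graph. A hedged sketch of such a toggle; only the attribute name comes from this diff, the helper itself is illustrative:

    import torch


    def disable_child_cudagraphs(module: torch.nn.Module) -> None:
        # Opt every TRT submodule out of its own per-engine capture before a
        # wrapper records one whole-module CUDA graph.
        for submodule in module.modules():
            if hasattr(submodule, "cudagraphs_disabled"):
                submodule.cudagraphs_disabled = True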
@@ -238,15 +239,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-
-        with (
-            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
-            if self.profiling_enabled
-            else nullcontext()
-        ):
+        with nvtx.annotate(f"Forward", color="red"):
             self._check_initialized()
+            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() and not self.cudagraphs_disabled
 
-            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
             need_cudagraphs_record = (
                 cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
             )
@@ -291,13 +287,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ]
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "PythonTorchTensorRTModule:ProcessInputs"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
+            with nvtx.annotate(f"ProcessInputs", color="red"):
                 assert len(contiguous_inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
@@ -359,13 +349,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     This could happen if the input tensor addresses/shapes haven't been configured correctly"
                 )
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "PythonTorchTensorRTModule:ProcessOutputs"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
+            with nvtx.annotate(f"ProcessOutputs", color="red"):
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
@@ -397,37 +381,35 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     output_name, outputs[o].data_ptr()
                 )
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "PythonTorchTensorRTModule:TensorRTRuntime"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
+            with nvtx.annotate(f"TensorRTRuntime", color="red"):
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
                     or self._engine_stream is None
                 ):
                     self._engine_stream = torch.cuda.Stream()
 
-                self._engine_stream.wait_stream(self._caller_stream)
+                with nvtx.annotate(f"wait_stream", color="green"):
+                    self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
-
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
+                            with nvtx.annotate(f"CUDAGraph", color="green"):
+                                self.cudagraph = torch.cuda.CUDAGraph()
 
                             if self.profiling_enabled:
                                 self.cudagraph.enable_debug_mode()
-
-                            with torch.cuda.graph(
-                                self.cudagraph, stream=self._engine_stream
-                            ):
-                                self.context.execute_async_v3(
-                                    self._engine_stream.cuda_stream
-                                )
+                            with nvtx.annotate(f"torch.cuda.graph", color="green"):
+                                with torch.cuda.graph(
+                                    self.cudagraph, stream=self._engine_stream
+                                ):
+                                    with nvtx.annotate(
+                                        f"execute_async_v3", color="green"
+                                    ):
+                                        self.context.execute_async_v3(
+                                            self._engine_stream.cuda_stream
+                                        )
 
                             if self.profiling_enabled:
                                 import tempfile
@@ -436,8 +418,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                     self.cudagraph.debug_dump(
                                         f"{tempdir}/{self.name}_cudagraph.dot"
                                     )
-
-                        self.cudagraph.replay()  # type: ignore
+                        with nvtx.annotate(f"replay", color="green"):
+                            self.cudagraph.replay()  # type: ignore
 
                     else:
                         self.context.execute_async_v3(self._engine_stream.cuda_stream)
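Note: the record/replay flow above follows PyTorch's documented torch.cuda.CUDAGraph workflow. A self-contained sketch of the same pattern, with a plain nn.Linear standing in for the TensorRT engine:

    import torch

    # CUDA graph replay reuses the captured memory addresses, so inputs are
    # static buffers updated with copy_() rather than rebound.
    model = torch.nn.Linear(16, 4).cuda().eval()
    static_input = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream so one-time allocations happen outside capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # Record: kernels launched inside the context are captured, not executed.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replay: copy fresh data into the captured input, then relaunch the graph.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_output)  # now holds results for the new input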
