Commit 4f464ef

chore: cleanup in WrapperTorchTensorRTModule
1 parent 817f9f9 · commit 4f464ef

File tree

- py/torch_tensorrt/dynamo/_compiler.py
- py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

2 files changed: +54 −30 lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 3 additions & 1 deletion
@@ -838,7 +838,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     if len(trt_modules) > 1:
         # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
         partitioned_module = WrapperTorchTensorRTModule(
-            partitioned_module, dryrun_tracker.output_dtypes
+            partitioned_module,
+            dryrun_tracker.output_shapes,
+            dryrun_tracker.output_dtypes,
         )
 
     return partitioned_module
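Note: passing dryrun_tracker.output_shapes up front means the wrapper no longer has to run the wrapped module once just to learn its output shapes when those shapes are static. A minimal sketch of why knowing shapes and dtypes at construction helps (the shapes and dtypes below are assumed examples, not values from this commit):

    import torch

    # With output shapes and dtypes known at construction time, output buffers
    # can be preallocated once instead of re-running the wrapped module to
    # infer them (assumes a CUDA device is available).
    output_shapes = [(1, 8), (1, 4)]                # assumed example shapes
    output_dtypes = [torch.float32, torch.float32]  # assumed example dtypes

    output_buffers = [
        torch.empty(shape, dtype=dtype, device="cuda")
        for shape, dtype in zip(output_shapes, output_dtypes)
    ]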

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

Lines changed: 51 additions & 29 deletions
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from tempfile import tempdir
 from typing import List, Optional, Sequence, Tuple
 
-import nvtx
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import input_is_dynamic
 from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device
 
 logger = logging.getLogger(__name__)
@@ -21,12 +22,13 @@ class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     def __init__(
         self,
         original_module: torch.nn.Module,
+        output_shapes: List[torch.Size],
         output_dtypes: List[torch.dtype],
     ):
         super(WrapperTorchTensorRTModule, self).__init__()
         self.original_module = original_module
         self.inputs = partitioning.construct_submodule_inputs(original_module)
-        self.output_shapes: List[torch.Tensor] = []
+        self.output_shapes = output_shapes
         self.output_dtypes = output_dtypes
 
         self._input_buffers: List[torch.Tensor] = []
@@ -37,6 +39,7 @@ def __init__(
         self.cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+        self.input_is_dynamic = input_is_dynamic(self.inputs)
 
         # Disable cudagrphs in submodules as it will be enabled in wrapper
         for name, rt_mod in self.original_module.named_children():
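The dynamic-input check is now computed once at construction instead of on every shape change. As a rough stand-in for what torch_tensorrt.dynamo.utils.input_is_dynamic checks (illustration only; the real helper's source may differ), an input spec counts as dynamic when it carries a shape range rather than a single static shape:

    from typing import Sequence, Union

    import torch
    from torch_tensorrt import Input

    def input_is_dynamic_sketch(inputs: Sequence[Union[Input, torch.Tensor]]) -> bool:
        # Hypothetical re-implementation for illustration: an Input built from
        # (min, opt, max) shapes reports a DYNAMIC shape mode.
        return any(
            i.shape_mode == Input._ShapeMode.DYNAMIC
            for i in inputs
            if isinstance(i, Input)
        )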
@@ -67,11 +70,12 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
             logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
 
-            # TODO: avoid it for static input shape
-            outputs = self.original_module(*inputs)
-            if not isinstance(outputs, (list, tuple)):
-                outputs = [outputs]
-            self.output_shapes = [tuple(output.shape) for output in outputs]
+            if self.input_is_dynamic:
+                tmp_outputs = self.original_module(*inputs)
+                if not isinstance(tmp_outputs, (list, tuple)):
+                    tmp_outputs = [tmp_outputs]
+                self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
+
             return True
 
         return False
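With this change, the extra forward pass to rediscover output shapes happens only for genuinely dynamic inputs; static-shape modules keep the shapes supplied to __init__. A self-contained sketch of the shape-key idea the method is built on (names here are illustrative, not the module's API):

    import torch

    def shapes_changed(inputs, state):
        # Build a key from the current input shapes; a mismatch means the
        # module is seeing a new shape configuration.
        new_key = "".join(str(tuple(t.shape)) for t in inputs)
        if new_key != state.get("shape_key"):
            state["shape_key"] = new_key
            return True
        return False

    state: dict = {}
    x = torch.randn(2, 3)
    assert shapes_changed([x], state)                   # first call records the key
    assert not shapes_changed([x], state)               # same shapes: nothing to redo
    assert shapes_changed([torch.randn(4, 3)], state)   # new shape: recompute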
@@ -86,8 +90,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-        with nvtx.annotate("Wrapper:Forward", color="orange"):
-
+        with (
+            torch.autograd.profiler.record_function(
+                "WrapperTorchTensorRTModule:Forward"
+            )
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             shape_changed = self.validate_input_shapes(inputs)
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
             # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
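The nvtx ranges are replaced throughout with torch.autograd.profiler.record_function, gated on the module's profiling flag so the annotation cost disappears entirely when profiling is off. The pattern in isolation (the local flag stands in for self.profiling_enabled):

    from contextlib import nullcontext

    import torch

    profiling_enabled = False  # stands in for self.profiling_enabled

    # nullcontext() keeps the `with` statement valid at near-zero cost when
    # profiling is off; when it is on, the region appears in profiler traces
    # under the given name.
    with (
        torch.autograd.profiler.record_function("Example:Forward")
        if profiling_enabled
        else nullcontext()
    ):
        y = torch.ones(4) * 2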
@@ -100,6 +109,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
+
                 self._input_buffers = [None] * len(self.inputs)
                 self._output_buffers = [None] * len(self.output_shapes)

@@ -139,15 +149,21 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ]
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
-            with nvtx.annotate("Wrapper:ProcessInputs", color="orange"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(contiguous_inputs) == len(
                     self.inputs
                 ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
 
-                for i, input_name in enumerate(self.inputs):
+                for i, _ in enumerate(self.inputs):
                     if not contiguous_inputs[i].is_cuda:
                         logger.warning(
-                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            f"Detected input[{i}] of engine {self.engine.name} is not on a cuda device. "
                             "This tensor is being moved by the runtime but for performance considerations, "
                             "ensure your inputs are all on GPU and open an issue here "
                             "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
@@ -169,7 +185,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     elif cudagraphs_enabled:
                         self._input_buffers[i].copy_(contiguous_inputs[i])
 
-            with nvtx.annotate("ProcessOutputs", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []

@@ -189,34 +211,35 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()
-
-            with nvtx.annotate("Wrapper:TensorRTRuntime", color="orange"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
                     or self._engine_stream is None
                 ):
                     self._engine_stream = torch.cuda.Stream()
 
-                with nvtx.annotate("wait_stream", color="green"):
-                    self._engine_stream.wait_stream(self._caller_stream)
+                self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            with nvtx.annotate("CUDAGraph", color="green"):
-                                self.cudagraph = torch.cuda.CUDAGraph()
+                            self.cudagraph = torch.cuda.CUDAGraph()
 
                             if self.profiling_enabled:
                                 self.cudagraph.enable_debug_mode()
-                            with nvtx.annotate("torch.cuda.graph", color="green"):
-                                with torch.cuda.graph(
-                                    self.cudagraph, stream=self._engine_stream
-                                ):
-                                    with nvtx.annotate("record", color="green"):
-                                        self._output_buffers = self.original_module(
-                                            *self._input_buffers
-                                        )
+                            with torch.cuda.graph(
+                                self.cudagraph, stream=self._engine_stream
+                            ):
+                                self._output_buffers = self.original_module(
+                                    *self._input_buffers
+                                )
 
                             if self.profiling_enabled:
                                 import tempfile
@@ -225,8 +248,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                     self.cudagraph.debug_dump(
                                         f"{tempdir}/{self.name}_cudagraph.dot"
                                     )
-                            with nvtx.annotate("replay", color="green"):
-                                self.cudagraph.replay()  # type: ignore
+                        self.cudagraph.replay()  # type: ignore
 
                     else:
                         outputs = self.original_module(*inputs)
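For reference, the capture/replay flow the wrapper implements follows the standard torch.cuda.CUDAGraph recipe: record once into static buffers, then replay with refreshed inputs. A minimal standalone sketch (requires a CUDA device; the Linear model is an arbitrary example, not part of this commit):

    import torch

    model = torch.nn.Linear(8, 4).cuda().eval()
    static_input = torch.randn(1, 8, device="cuda")

    with torch.no_grad():
        # Warm up on a side stream before capture, per the PyTorch CUDA graphs docs.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        # Record once: tensors produced inside the capture become static buffers.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = model(static_input)

    # Replay: refresh the static input in place, then rerun the recorded work.
    static_input.copy_(torch.randn(1, 8, device="cuda"))
    graph.replay()
    print(static_output)  # now holds the result for the refreshed input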
