 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from tempfile import tempdir
 from typing import List, Optional, Sequence, Tuple
 
-import nvtx
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import input_is_dynamic
 from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device
 
 logger = logging.getLogger(__name__)
@@ -21,12 +22,13 @@ class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     def __init__(
         self,
         original_module: torch.nn.Module,
+        output_shapes: List[torch.Size],
         output_dtypes: List[torch.dtype],
     ):
         super(WrapperTorchTensorRTModule, self).__init__()
         self.original_module = original_module
         self.inputs = partitioning.construct_submodule_inputs(original_module)
-        self.output_shapes: List[torch.Tensor] = []
+        self.output_shapes = output_shapes
         self.output_dtypes = output_dtypes
 
         self._input_buffers: List[torch.Tensor] = []
@@ -37,6 +39,7 @@ def __init__(
         self.cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+        self.input_is_dynamic = input_is_dynamic(self.inputs)
 
         # Disable cudagraphs in submodules as it will be enabled in wrapper
         for name, rt_mod in self.original_module.named_children():
@@ -67,11 +70,12 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
             logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
 
-            # TODO: avoid it for static input shape
-            outputs = self.original_module(*inputs)
-            if not isinstance(outputs, (list, tuple)):
-                outputs = [outputs]
-            self.output_shapes = [tuple(output.shape) for output in outputs]
+            if self.input_is_dynamic:
+                tmp_outputs = self.original_module(*inputs)
+                if not isinstance(tmp_outputs, (list, tuple)):
+                    tmp_outputs = [tmp_outputs]
+                self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
+
             return True
 
         return False
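Note: the rewritten block above re-runs the original module to rediscover output shapes only when the module was compiled with dynamic input shapes; for static shapes, the shapes passed to the constructor stay valid for the lifetime of the module. A minimal sketch of that gating, using a hypothetical helper (infer_output_shapes is not part of the torch_tensorrt API):

from typing import List, Tuple

import torch

def infer_output_shapes(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    is_dynamic: bool,
    cached: List[Tuple[int, ...]],
) -> List[Tuple[int, ...]]:
    if not is_dynamic:
        # Static inputs: shapes known at construction time remain correct.
        return cached
    # Dynamic inputs: run the module once to observe concrete output shapes.
    outputs = module(*inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    return [tuple(o.shape) for o in outputs]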
@@ -86,8 +90,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-        with nvtx.annotate("Wrapper:Forward", color="orange"):
-
+        with (
+            torch.autograd.profiler.record_function(
+                "WrapperTorchTensorRTModule:Forward"
+            )
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             shape_changed = self.validate_input_shapes(inputs)
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
             # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
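Note: the conditional with-block introduced here (and repeated in the hunks below) replaces the unconditional nvtx ranges: torch.autograd.profiler.record_function emits a named range into torch.profiler traces, while contextlib.nullcontext makes the with statement a no-op when profiling is off. A standalone sketch of the pattern, with an illustrative label and toggle:

from contextlib import nullcontext

import torch

def profiled(label: str, enabled: bool):
    # A named profiler range when enabled, otherwise a do-nothing context.
    return torch.autograd.profiler.record_function(label) if enabled else nullcontext()

with profiled("Example:Forward", enabled=True):
    y = torch.relu(torch.randn(4))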
@@ -100,6 +109,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
+
                 self._input_buffers = [None] * len(self.inputs)
                 self._output_buffers = [None] * len(self.output_shapes)
@@ -139,15 +149,21 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ]
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
-            with nvtx.annotate("Wrapper:ProcessInputs", color="orange"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(contiguous_inputs) == len(
                     self.inputs
                 ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
 
-                for i, input_name in enumerate(self.inputs):
+                for i, _ in enumerate(self.inputs):
                     if not contiguous_inputs[i].is_cuda:
                         logger.warning(
-                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            f"Detected input[{i}] of engine {self.engine.name} is not on a cuda device. "
                             "This tensor is being moved by the runtime but for performance considerations, "
                             "ensure your inputs are all on GPU and open an issue here "
                             "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
@@ -169,7 +185,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 elif cudagraphs_enabled:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
 
-            with nvtx.annotate("ProcessOutputs", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []
@@ -189,34 +211,35 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()
-
-            with nvtx.annotate("Wrapper:TensorRTRuntime", color="orange"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
                     or self._engine_stream is None
                 ):
                     self._engine_stream = torch.cuda.Stream()
 
-                with nvtx.annotate("wait_stream", color="green"):
-                    self._engine_stream.wait_stream(self._caller_stream)
+                self._engine_stream.wait_stream(self._caller_stream)
 
                 with torch.cuda.stream(self._engine_stream):
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            with nvtx.annotate("CUDAGraph", color="green"):
-                                self.cudagraph = torch.cuda.CUDAGraph()
+                            self.cudagraph = torch.cuda.CUDAGraph()
 
                             if self.profiling_enabled:
                                 self.cudagraph.enable_debug_mode()
-                            with nvtx.annotate("torch.cuda.graph", color="green"):
-                                with torch.cuda.graph(
-                                    self.cudagraph, stream=self._engine_stream
-                                ):
-                                    with nvtx.annotate("record", color="green"):
-                                        self._output_buffers = self.original_module(
-                                            *self._input_buffers
-                                        )
+                            with torch.cuda.graph(
+                                self.cudagraph, stream=self._engine_stream
+                            ):
+                                self._output_buffers = self.original_module(
+                                    *self._input_buffers
+                                )
 
                             if self.profiling_enabled:
                                 import tempfile
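Note: the capture path above follows the standard torch.cuda.CUDAGraph recipe: warm up on a side stream, capture the work into static buffers, then refill those buffers and replay. A minimal sketch, assuming a CUDA device is available:

import torch

static_in = torch.randn(8, device="cuda")

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    static_out = static_in * 2
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = static_in * 2  # recorded into the graph, not executed here

# Replay runs the captured kernels against the current buffer contents.
static_in.copy_(torch.randn(8, device="cuda"))
graph.replay()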
@@ -225,8 +248,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                     self.cudagraph.debug_dump(
                                         f"{tempdir}/{self.name}_cudagraph.dot"
                                     )
-                        with nvtx.annotate("replay", color="green"):
-                            self.cudagraph.replay()  # type: ignore
+                        self.cudagraph.replay()  # type: ignore
 
                     else:
                         outputs = self.original_module(*inputs)
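Note: the wrapper only records and replays a graph when the global cudagraphs mode is on, which it checks through get_cudagraphs_mode() above. A usage sketch, assuming the set_cudagraphs_mode toggle that pairs with that getter, with placeholder names for the compiled module and input:

import torch_tensorrt

# trt_module: a Torch-TensorRT compiled module; x: a CUDA input tensor.
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_module(x)  # first call records the graph; later calls replay it
torch_tensorrt.runtime.set_cudagraphs_mode(False)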