@@ -1,13 +1,18 @@
 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
 import torch
 from torch.nn import Module
+from torch_tensorrt._Device import Device
+from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
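The new nullcontext import enables the conditional-profiling pattern applied throughout forward() below: the with-statement's context expression evaluates to a record_function range only when profiling is on, and to a no-op otherwise. A minimal sketch of the pattern, with illustrative names not taken from the diff:

import torch
from contextlib import nullcontext

profiling_enabled = False  # toggled at runtime, as in this PR

# When profiling is disabled, nullcontext() turns the with-statement into
# a no-op instead of emitting a profiler range on every call.
with torch.autograd.profiler.record_function(
    "Example:Step"
) if profiling_enabled else nullcontext():
    result = torch.ones(4) * 2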
@@ -23,13 +28,22 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
+        target_device: Device = Device._current_device(),
+        profiling_enabled: Optional[bool] = None,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
         self.initialized = False
+        self.target_device_id = target_device.gpu_id
+        self.target_device_properties = torch.cuda.get_device_properties(
+            self.target_device_id
+        )
+        self.profiling_enabled = (
+            profiling_enabled if profiling_enabled is not None else False
+        )
         self._initialize()
 
     def _initialize(self) -> None:
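For reference, constructing the module with the new keyword arguments might look like the following sketch; the engine deserialization is elided, and passing a "cuda:0" string to Device is an assumption about the torch_tensorrt Device API rather than something shown in this diff:

from torch_tensorrt._Device import Device

# Hypothetical usage; `engine` is an already-deserialized trt.ICudaEngine.
module = PythonTorchTensorRTModule(
    engine,
    input_names=["input_0"],          # illustrative binding names
    output_names=["output_0"],
    target_device=Device("cuda:0"),   # assumption: string form accepted
    profiling_enabled=False,          # matches the default of None -> False
)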
@@ -141,15 +155,41 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         if self.engine:
             self.context = self.engine.create_execution_context()
 
-    def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:Forward"
-        ):
+        ) if self.profiling_enabled else nullcontext():
            self._check_initialized()
 
+            # If in safe mode, check at each iteration whether a device switch is required
+            if torch_tensorrt._compile.SAFE_MODE:
+                curr_device_id = torch.cuda.current_device()
+                curr_device_properties = torch.cuda.get_device_properties(
+                    curr_device_id
+                )
+                logger.debug(f"Current Device: cuda:{curr_device_id}")
+
+                # If a switch is required, move all inputs to the new device and set it as the active device
+                if _is_switch_required(
+                    curr_device_id,
+                    self.target_device_id,
+                    curr_device_properties,
+                    self.target_device_properties,
+                ):
+                    device_id, _ = _select_rt_device(
+                        curr_device_id,
+                        self.target_device_id,
+                        self.target_device_properties,
+                    )
+                    device = torch.device(device_id)
+                    torch.cuda.set_device(device_id)
+
+                    inputs = tuple([tensor.to(device) for tensor in inputs])
+                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessInputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
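_is_switch_required and _select_rt_device are imported from torch_tensorrt.dynamo.runtime.tools, whose bodies are not part of this diff. As a rough sketch of the intent, assuming the decision is keyed on device index and compute capability:

# Illustrative only: the real implementation lives in
# torch_tensorrt.dynamo.runtime.tools and may use different criteria.
def _is_switch_required_sketch(curr_id, target_id, curr_props, target_props):
    # Switch when the active device is not the engine's target device, or
    # when its compute capability no longer matches the target's.
    return curr_id != target_id or (curr_props.major, curr_props.minor) != (
        target_props.major,
        target_props.minor,
    )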
@@ -162,22 +202,24 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
 
                 for i, input_name in enumerate(self.input_names):
-                    if not contiguous_inputs[i].is_cuda:
-                        logger.warning(
-                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
-                            "This tensor is being moved by the runtime but for performance considerations, "
-                            "ensure your inputs are all on GPU and open an issue here "
-                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                        )
-                        contiguous_inputs = (
-                            contiguous_inputs[:i]
-                            + [contiguous_inputs[i].cuda()]
-                            + contiguous_inputs[i + 1 :]
-                        )
-
-                    assert (
-                        contiguous_inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+                    # Check that the inputs are on cuda and have the correct data type if in safe mode
+                    if torch_tensorrt._compile.SAFE_MODE:
+                        if not contiguous_inputs[i].is_cuda:
+                            logger.warning(
+                                f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                                "This tensor is being moved by the runtime but for performance considerations, "
+                                "ensure your inputs are all on GPU and open an issue here "
+                                "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                            )
+                            contiguous_inputs = (
+                                contiguous_inputs[:i]
+                                + [contiguous_inputs[i].cuda()]
+                                + contiguous_inputs[i + 1 :]
+                            )
+
+                        assert (
+                            contiguous_inputs[i].dtype == self.input_dtypes[i]
+                        ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
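Gating these per-input checks on SAFE_MODE means the device check, the implicit host-to-device copy, and the dtype assertion are all skipped in the fast path. Assuming the module-level flag referenced in this diff is the intended switch (a public setter may exist elsewhere), disabling it would look like:

import torch_tensorrt

# Assumption: SAFE_MODE is the global consulted in forward() above.
# Turning it off trades per-iteration validation for lower latency.
torch_tensorrt._compile.SAFE_MODE = False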
@@ -188,7 +230,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
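The elided lines under this hunk allocate the output buffers. A sketch of how that typically works with the pre-8.5 TensorRT binding API implied by execute_async_v2 below (attribute names are assumptions, not shown in the diff):

# Sketch: allocate one tensor per output binding and record its pointer.
for i, idx in enumerate(self.output_binding_indices_in_order):
    shape = tuple(self.context.get_binding_shape(idx))
    output = torch.empty(
        size=shape,
        dtype=self.output_dtypes[i],
        device=f"cuda:{torch.cuda.current_device()}",
    )
    outputs.append(output)
    bindings[idx] = output.data_ptr()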
@@ -215,7 +257,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
@@ -235,6 +277,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         if not self.context.profiler:
             self.context.profiler = trt.Profiler() if profiler is None else profiler
 
+        self.profiling_enabled = True
+
     def disable_profiling(self) -> None:
         """
         Disable TensorRT profiling.
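Since enable_profiling() now also flips profiling_enabled, the record_function ranges in forward() only appear when profiling has been requested. A usage sketch (the input tensor x is assumed):

# Usage sketch: profile one forward pass through the TensorRT module.
module.enable_profiling()
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    out = module(x)  # x: a CUDA input tensor (assumed)
print(prof.key_averages().table(sort_by="cuda_time_total"))
module.disable_profiling()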
@@ -244,6 +288,7 @@ def disable_profiling(self) -> None:
         torch.cuda.synchronize()
         del self.context
         self.context = self.engine.create_execution_context()
+        self.profiling_enabled = False
 
     def get_layer_info(self) -> str:
         """