
Commit 09b01e0

feat: Add device checks in Python Runtime
- Add support for device checks in the Python Runtime, to mirror those in the C++ runtime
- Fix various issues in partitioning and runtime accordingly
1 parent 91c93f5 commit 09b01e0

File tree

9 files changed: +251 −41 lines changed

core/runtime/execute_engine.cpp

Lines changed: 9 additions & 7 deletions
@@ -129,13 +129,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   }
   for (size_t i = 0; i < inputs.size(); i++) {
     std::string name = compiled_engine->in_binding_names[i];
-    TORCHTRT_CHECK(
-        inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
-    auto expected_type =
-        util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
-    TORCHTRT_CHECK(
-        inputs[i].dtype() == expected_type,
-        "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
+    if (SAFE_MODE) {
+      TORCHTRT_CHECK(
+          inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
+      auto expected_type =
+          util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
+      TORCHTRT_CHECK(
+          inputs[i].dtype() == expected_type,
+          "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
+    }
     auto dims = core::util::toDims(inputs[i].sizes());
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
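In safe mode (the default), these checks reject inputs on the wrong device or with the wrong dtype before they are bound to the engine; unsafe mode skips them for lower per-call overhead. A minimal sketch of the observable behavior, assuming a hypothetical module trt_mod compiled for float32 CUDA inputs:

    import torch

    x = torch.randn(1, 3, 224, 224)   # CPU tensor
    # trt_mod(x)                # fails: "Expected input tensors to have device cuda, found device cpu"
    # trt_mod(x.cuda().half())  # fails: "Expected input tensors to have type Float, found type Half"
    y = trt_mod(x.cuda())       # passes both checks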

core/runtime/runtime.cpp

Lines changed: 2 additions & 2 deletions
@@ -33,13 +33,13 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
     if (device.device_name == target_device.device_name) {
       // First priority is selecting a candidate which agrees with the current device ID
       // If such a device is found, we can select it and break out of the loop
-      if (device.id == current_device.id && best_match.id != current_device.id) {
+      if (device.id == current_device.id) {
         best_match = device;
         break;
       }
       // Second priority is selecting a candidate which agrees with the target device ID
       // At deserialization time, the current device and target device may not agree
-      else if (device.id == target_device.id && best_match.id != target_device.id) {
+      else if (device.id == target_device.id) {
         best_match = device;
       }
       // If no such GPU ID is found, select the first available candidate GPU
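The simplified conditions make the priority order explicit: a candidate matching the current device ID wins immediately, one matching the target device ID is kept as the running best match, and otherwise the first available candidate is used. An illustrative Python sketch of that priority (not the actual C++ API):

    def most_compatible(candidates, current_id, target_id):
        # candidates are assumed pre-filtered to the target device name
        best = None
        for dev in candidates:
            if dev.id == current_id:
                return dev   # first priority: agrees with the active device
            if dev.id == target_id:
                best = dev   # second priority: agrees with the target device
        # neither ID matched: fall back to the first available candidate
        return best if best is not None else (candidates[0] if candidates else None)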

py/torch_tensorrt/_compile.py

Lines changed: 8 additions & 2 deletions
@@ -20,6 +20,8 @@

 DYNAMO_ENABLED = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")

+SAFE_MODE = torch.ops.tensorrt.get_safe_mode()
+
 if DYNAMO_ENABLED:
     from torch._export import ExportedProgram
     from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
@@ -256,18 +258,22 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
     return boxed_fn


-def enable_unsafe_inference_mode():
+def enable_unsafe_inference_mode() -> None:
     """
     Enables unsafe inference mode for Torch-TensorRT
     """
+    global SAFE_MODE
+    SAFE_MODE = False
     torch.ops.tensorrt.set_safe_mode(False)
     logger.info("Enabled unsafe inference mode")


-def enable_safe_inference_mode():
+def enable_safe_inference_mode() -> None:
+    """
     """
     Enables safe inference mode for Torch-TensorRT
     """
+    global SAFE_MODE
+    SAFE_MODE = True
     torch.ops.tensorrt.set_safe_mode(True)
     logger.info("Enabled safe inference mode")
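Keeping the module-level SAFE_MODE flag in sync with the C++ runtime's set_safe_mode op lets the Python runtime consult a plain module-level global on each call rather than crossing into the custom op. A usage sketch (assuming these helpers are re-exported from the torch_tensorrt package root, as other _compile helpers are):

    import torch_tensorrt

    torch_tensorrt.enable_unsafe_inference_mode()  # skip per-call device/dtype checks
    # ... latency-sensitive inference with known-good inputs ...
    torch_tensorrt.enable_safe_inference_mode()    # restore the default checks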

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 3 additions & 2 deletions
@@ -3,15 +3,14 @@
 import io
 from typing import Sequence

+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs

-import tensorrt as trt
-

 def convert_module(
     module: torch.fx.GraphModule,
@@ -72,6 +71,8 @@ def convert_module(
             engine=interpreter_result.engine,
             input_names=list(interpreter_result.input_names),
             output_names=list(interpreter_result.output_names),
+            target_device=settings.device,
+            profiling_enabled=settings.debug,
         )

     else:
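convert_module now threads the compilation settings into the Python runtime module: settings.device becomes the target device for runtime device checks, and settings.debug toggles the profiling contexts. A hypothetical direct construction mirroring that call (the engine value and the Device string form are assumptions):

    from torch_tensorrt._Device import Device
    from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

    trt_module = PythonTorchTensorRTModule(
        engine=engine,                   # a trt.ICudaEngine built by the interpreter
        input_names=["x"],
        output_names=["y"],
        target_device=Device("cuda:0"),  # device the engine is expected to run on
        profiling_enabled=False,
    )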

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -42,10 +42,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -150,10 +150,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:
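Both partitioners make the same two adjustments: every get_attr node is now treated as supported, so parameter and constant accesses stay inside TRT subgraphs regardless of their name, but get_attr nodes no longer count toward the supported-operator tally, which tracks computational ops only. A condensed restatement of the predicate (illustrative, not the actual class method):

    def classify(node, node_name, converters, torch_executed_ops):
        # Supported: has a converter, or is a get_attr (constant/parameter access)
        supported = (
            node in converters or node.op == "get_attr"
        ) and node_name not in torch_executed_ops
        # Counted: only pure computational nodes, never get_attr accesses
        counted = supported and not node.is_impure() and node.op != "get_attr"
        return supported, counted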

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 66 additions & 21 deletions
@@ -1,13 +1,18 @@
 from __future__ import annotations

 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple

 import tensorrt as trt
 import torch
 from torch.nn import Module
+from torch_tensorrt._Device import Device
+from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

+import torch_tensorrt
+
 logger = logging.getLogger(__name__)


@@ -23,13 +28,22 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
+        target_device: Device = Device._current_device(),
+        profiling_enabled: Optional[bool] = None,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
         self.initialized = False
+        self.target_device_id = target_device.gpu_id
+        self.target_device_properties = torch.cuda.get_device_properties(
+            self.target_device_id
+        )
+        self.profiling_enabled = (
+            profiling_enabled if profiling_enabled is not None else False
+        )
         self._initialize()

     def _initialize(self) -> None:
@@ -141,15 +155,41 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         if self.engine:
             self.context = self.engine.create_execution_context()

-    def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:Forward"
-        ):
+        ) if self.profiling_enabled else nullcontext():
             self._check_initialized()

+            # If in safe mode, check at each iteration whether a switch is required
+            if torch_tensorrt._compile.SAFE_MODE:
+                curr_device_id = torch.cuda.current_device()
+                curr_device_properties = torch.cuda.get_device_properties(
+                    curr_device_id
+                )
+                logger.debug(f"Current Device: cuda:{curr_device_id}")
+
+                # If a switch is required, move all inputs to new device and set as active device
+                if _is_switch_required(
+                    curr_device_id,
+                    self.target_device_id,
+                    curr_device_properties,
+                    self.target_device_properties,
+                ):
+                    device_id, _ = _select_rt_device(
+                        curr_device_id,
+                        self.target_device_id,
+                        self.target_device_properties,
+                    )
+                    device = torch.device(device_id)
+                    torch.cuda.set_device(device_id)
+
+                    inputs = tuple([tensor.to(device) for tensor in inputs])
+                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessInputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -162,22 +202,24 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )

                 for i, input_name in enumerate(self.input_names):
-                    if not contiguous_inputs[i].is_cuda:
-                        logger.warning(
-                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
-                            "This tensor is being moved by the runtime but for performance considerations, "
-                            "ensure your inputs are all on GPU and open an issue here "
-                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                        )
-                        contiguous_inputs = (
-                            contiguous_inputs[:i]
-                            + [contiguous_inputs[i].cuda()]
-                            + contiguous_inputs[i + 1 :]
-                        )
-
-                    assert (
-                        contiguous_inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+                    # Check that the inputs are on cuda and have the correct data type if in safe mode
+                    if torch_tensorrt._compile.SAFE_MODE:
+                        if not contiguous_inputs[i].is_cuda:
+                            logger.warning(
+                                f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                                "This tensor is being moved by the runtime but for performance considerations, "
+                                "ensure your inputs are all on GPU and open an issue here "
+                                "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                            )
+                            contiguous_inputs = (
+                                contiguous_inputs[:i]
+                                + [contiguous_inputs[i].cuda()]
+                                + contiguous_inputs[i + 1 :]
+                            )
+
+                        assert (
+                            contiguous_inputs[i].dtype == self.input_dtypes[i]
+                        ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
@@ -188,7 +230,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:

             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 # create output tensors
                 outputs: List[torch.Tensor] = []

@@ -215,7 +257,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:

             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
@@ -235,6 +277,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         if not self.context.profiler:
            self.context.profiler = trt.Profiler() if profiler is None else profiler

+        self.profiling_enabled = True
+
     def disable_profiling(self) -> None:
         """
         Disable TensorRT profiling.
@@ -244,6 +288,7 @@ def disable_profiling(self) -> None:
         torch.cuda.synchronize()
         del self.context
         self.context = self.engine.create_execution_context()
+        self.profiling_enabled = False

     def get_layer_info(self) -> str:
         """
