Skip to content

Commit b434aee

Browse files
committed
feat: Add documentation and warnings for Safe Mode
1 parent 922c352 commit b434aee

File tree

13 files changed

+162
-60
lines changed

13 files changed

+162
-60
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TRTEngine::TRTEngine(
5252
auto most_compatible_device = get_most_compatible_device(cuda_device);
5353
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
5454
device_info = most_compatible_device.value();
55+
multi_gpu_device_check(device_info);
5556
set_rt_device(device_info);
5657

5758
rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));

core/runtime/execute_engine.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
7474
LOG_INFO("" << log_info);
7575
}
7676

77-
if (SAFE_MODE) {
77+
if (MULTI_DEVICE_SAFE_MODE) {
7878
std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
7979
if (compiled_engine->profile_execution) {
8080
device_profiler_guard =
@@ -129,15 +129,13 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
129129
}
130130
for (size_t i = 0; i < inputs.size(); i++) {
131131
std::string name = compiled_engine->in_binding_names[i];
132-
if (SAFE_MODE) {
133-
TORCHTRT_CHECK(
134-
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
135-
auto expected_type =
136-
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
137-
TORCHTRT_CHECK(
138-
inputs[i].dtype() == expected_type,
139-
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
140-
}
132+
TORCHTRT_CHECK(
133+
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
134+
auto expected_type =
135+
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
136+
TORCHTRT_CHECK(
137+
inputs[i].dtype() == expected_type,
138+
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
141139
auto dims = core::util::toDims(inputs[i].sizes());
142140
auto shape = core::util::toVec(dims);
143141
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

core/runtime/register_jit_hooks.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,10 @@ TORCH_LIBRARY(tensorrt, m) {
114114
m.def("execute_engine", execute_engine);
115115
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
116116
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
117-
m.def("get_safe_mode", []() -> bool { return SAFE_MODE; });
118-
m.def("set_safe_mode", [](bool safe_mode) -> void { SAFE_MODE = safe_mode; });
117+
m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
118+
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
119+
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
120+
});
119121
}
120122

121123
} // namespace

core/runtime/runtime.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace torch_tensorrt {
77
namespace core {
88
namespace runtime {
99

10-
bool SAFE_MODE = true;
10+
bool MULTI_DEVICE_SAFE_MODE = false;
1111

1212
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
1313
LOG_DEBUG("Target Device: " << target_device);
@@ -105,6 +105,19 @@ RTDevice get_current_device() {
105105
return RTDevice(device_id, nvinfer1::DeviceType::kGPU);
106106
}
107107

108+
void multi_gpu_device_check(const RTDevice& most_compatible_device) {
109+
// If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
110+
if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) {
111+
LOG_WARNING(
112+
"Detected this engine is being instantitated in a multi-GPU system with "
113+
<< "multi-device safe mode disabled. For more on the implications of this "
114+
<< "as well as workarounds, see MULTI_DEVICE_SAFE_MODE.md "
115+
<< "(https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/MULTI_DEVICE_SAFE_MODE.md). "
116+
<< "The engine is set to be instantiated on the cuda device, " << most_compatible_device << ". "
117+
<< "If this is incorrect, please set the desired cuda device as default and retry.");
118+
}
119+
}
120+
108121
namespace {
109122
static DeviceList cuda_device_list;
110123
}

core/runtime/runtime.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace runtime {
1616

1717
using EngineID = int64_t;
1818
const std::string ABI_VERSION = "4";
19-
extern bool SAFE_MODE;
19+
extern bool MULTI_DEVICE_SAFE_MODE;
2020
typedef enum {
2121
ABI_TARGET_IDX = 0,
2222
NAME_IDX,
@@ -34,6 +34,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
3434

3535
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
3636

37+
void multi_gpu_device_check(const RTDevice& most_compatible_device);
38+
3739
class DeviceList {
3840
using DeviceMap = std::unordered_map<int, RTDevice>;
3941
DeviceMap device_list;

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,14 @@ def _find_lib(name: str, paths: List[str]) -> str:
8282

8383
import torch
8484
from torch_tensorrt._compile import * # noqa: F403
85-
from torch_tensorrt._compile import (
86-
enable_safe_inference_mode,
87-
enable_unsafe_inference_mode,
88-
)
8985
from torch_tensorrt._Device import Device # noqa: F401
9086
from torch_tensorrt._enums import * # noqa: F403
9187
from torch_tensorrt._Input import Input # noqa: F401
9288
from torch_tensorrt._utils import * # noqa: F403
9389
from torch_tensorrt._utils import sanitized_torch_version
9490
from torch_tensorrt.logging import *
9591
from torch_tensorrt.ptq import *
92+
from torch_tensorrt.runtime import * # noqa: F403
9693

9794
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9895
from torch_tensorrt.dynamo import backend # noqa: F401

py/torch_tensorrt/_compile.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
DYNAMO_ENABLED = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
2222

23-
SAFE_MODE = torch.ops.tensorrt.get_safe_mode()
24-
2523
if DYNAMO_ENABLED:
2624
from torch._export import ExportedProgram
2725
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
@@ -258,26 +256,6 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
258256
return boxed_fn
259257

260258

261-
def enable_unsafe_inference_mode() -> None:
262-
"""
263-
Enables unsafe inference mode for Torch-TensorRT
264-
"""
265-
global SAFE_MODE
266-
SAFE_MODE = False
267-
torch.ops.tensorrt.set_safe_mode(False)
268-
logger.info("Enabled unsafe inference mode")
269-
270-
271-
def enable_safe_inference_mode() -> None:
272-
"""
273-
Enables safe inference mode for Torch-TensorRT
274-
"""
275-
global SAFE_MODE
276-
SAFE_MODE = True
277-
torch.ops.tensorrt.set_safe_mode(True)
278-
logger.info("Enabled safe inference mode")
279-
280-
281259
def convert_method_to_trt_engine(
282260
module: Any,
283261
method_name: str = "forward",
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether
2+
the runtime checks for device consistency prior to every inference call.
3+
4+
There is a non-negligible, fixed cost per-inference call when multi-device safe mode, which is why
5+
it is now disabled by default. It can be controlled via the following convenience function which
6+
doubles as a context manager.
7+
```python
8+
# Enables Multi Device Safe Mode
9+
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
10+
11+
# Disables Multi Device Safe Mode [Default Behavior]
12+
torch_tensorrt.runtime.set_multi_device_safe_mode(False)
13+
14+
# Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
15+
with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
16+
...
17+
```
18+
TensorRT requires that each engine be associated with the CUDA context in the active thread from which it is invoked.
19+
Therefore, if the device were to change in the active thread, which may be the case when invoking
20+
engines on multiple GPUs from the same Python process, safe mode will cause Torch-TensorRT to display
21+
an alert and switch GPUs accordingly. If safe mode were not enabled, there could be a mismatch in the engine
22+
device and CUDA context device, which could lead the program to crash.
23+
24+
One technique for managing multiple TRT engines on different GPUs while not sacrificing performance for
25+
multi-device safe mode is to use Python threads. Each thread is responsible for all of the TRT engines
26+
on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is
27+
responsible (can be set via `torch.cuda.set_device(...)`). In this way, multiple threads can be used in the same
28+
Python scripts without needing to switch CUDA contexts and incur performance overhead by leveraging threads.

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
import torch
99
from torch.nn import Module
1010
from torch_tensorrt._Device import Device
11-
from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
11+
from torch_tensorrt.dynamo.runtime.tools import (
12+
_is_switch_required,
13+
_select_rt_device,
14+
multi_gpu_device_check,
15+
)
1216
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1317

1418
import torch_tensorrt
@@ -33,6 +37,10 @@ def __init__(
3337
):
3438
super(PythonTorchTensorRTModule, self).__init__()
3539
self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
40+
41+
# Run multi-gpu device check to validate engine instantiation
42+
multi_gpu_device_check()
43+
3644
self.engine = engine
3745
self.input_names = input_names if input_names is not None else []
3846
self.output_names = output_names if output_names is not None else []
@@ -133,6 +141,9 @@ def _load_from_state_dict(
133141
) -> None:
134142
engine_bytes = state_dict[prefix + "engine"]
135143

144+
# Run multi-gpu device check to validate engine instantiation
145+
multi_gpu_device_check()
146+
136147
logger = trt.Logger()
137148
runtime = trt.Runtime(logger)
138149
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
@@ -162,7 +173,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
162173
self._check_initialized()
163174

164175
# If in safe mode, check at each iteration for for whether a switch is required
165-
if torch_tensorrt._compile.SAFE_MODE:
176+
if (
177+
torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
178+
):
166179
curr_device_id = torch.cuda.current_device()
167180
curr_device_properties = torch.cuda.get_device_properties(
168181
curr_device_id
@@ -202,24 +215,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
202215
)
203216

204217
for i, input_name in enumerate(self.input_names):
205-
# Check that the inputs are on cuda and have the correct data type if in safe mode
206-
if torch_tensorrt._compile.SAFE_MODE:
207-
if not contiguous_inputs[i].is_cuda:
208-
logger.warning(
209-
f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
210-
"This tensor is being moved by the runtime but for performance considerations, "
211-
"ensure your inputs are all on GPU and open an issue here "
212-
"(https://github.com/pytorch/TensorRT/issues) if this warning persists."
213-
)
214-
contiguous_inputs = (
215-
contiguous_inputs[:i]
216-
+ [contiguous_inputs[i].cuda()]
217-
+ contiguous_inputs[i + 1 :]
218-
)
219-
220-
assert (
221-
contiguous_inputs[i].dtype == self.input_dtypes[i]
222-
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
218+
if not contiguous_inputs[i].is_cuda:
219+
logger.warning(
220+
f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
221+
"This tensor is being moved by the runtime but for performance considerations, "
222+
"ensure your inputs are all on GPU and open an issue here "
223+
"(https://github.com/pytorch/TensorRT/issues) if this warning persists."
224+
)
225+
contiguous_inputs = (
226+
contiguous_inputs[:i]
227+
+ [contiguous_inputs[i].cuda()]
228+
+ contiguous_inputs[i + 1 :]
229+
)
230+
231+
assert (
232+
contiguous_inputs[i].dtype == self.input_dtypes[i]
233+
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
223234

224235
idx = self.input_binding_indices_in_order[i]
225236
bindings[idx] = contiguous_inputs[i].data_ptr()

py/torch_tensorrt/dynamo/runtime/tools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,27 @@
33

44
import torch
55

6+
import torch_tensorrt
7+
68
logger = logging.getLogger(__name__)
79

810

11+
def multi_gpu_device_check() -> None:
12+
# If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
13+
if (
14+
not torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
15+
and torch.cuda.device_count() > 1
16+
):
17+
logger.warning(
18+
"Detected this engine is being instantitated in a multi-GPU system with "
19+
"multi-device safe mode disabled. For more on the implications of this "
20+
"as well as workarounds, see MULTI_DEVICE_SAFE_MODE.md "
21+
"(https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/MULTI_DEVICE_SAFE_MODE.md). "
22+
f"The engine is set to be instantiated on the current default cuda device, cuda:{torch.cuda.current_device()}. "
23+
"If this is incorrect, please set the desired cuda device via torch.cuda.set_device(...) and retry."
24+
)
25+
26+
927
def _is_switch_required(
1028
curr_device_id: int,
1129
engine_device_id: int,

py/torch_tensorrt/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .multi_device_safe_mode import set_multi_device_safe_mode
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import logging
2+
from importlib.util import find_spec
3+
from typing import Any
4+
5+
import torch
6+
7+
if find_spec("torch_tensorrt._C") is not None:
8+
_PY_RT_MULTI_DEVICE_SAFE_MODE = torch.ops.tensorrt.get_multi_device_safe_mode()
9+
else:
10+
_PY_RT_MULTI_DEVICE_SAFE_MODE = False
11+
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class _MultiDeviceSafeModeContextManager(object):
17+
"""Helper class used in conjunction with `set_multi_device_safe_mode`
18+
19+
Used to enable `set_multi_device_safe_mode` as a dual-purpose context manager
20+
"""
21+
22+
def __init__(self, old_mode: bool) -> None:
23+
self.old_mode = old_mode
24+
25+
def __enter__(self) -> "_MultiDeviceSafeModeContextManager":
26+
return self
27+
28+
def __exit__(self, *args: Any) -> None:
29+
# Set multi-device safe mode back to old mode in Python
30+
global _PY_RT_MULTI_DEVICE_SAFE_MODE
31+
_PY_RT_MULTI_DEVICE_SAFE_MODE = self.old_mode
32+
33+
# Set multi-device safe mode back to old mode in C++
34+
if find_spec("torch_tensorrt._C") is not None:
35+
torch.ops.tensorrt.set_multi_device_safe_mode(self.old_mode)
36+
37+
38+
def set_multi_device_safe_mode(mode: bool) -> _MultiDeviceSafeModeContextManager:
39+
# Fetch existing safe mode and set new mode for Python
40+
global _PY_RT_MULTI_DEVICE_SAFE_MODE
41+
old_mode = _PY_RT_MULTI_DEVICE_SAFE_MODE
42+
_PY_RT_MULTI_DEVICE_SAFE_MODE = mode
43+
44+
# Set new mode for C++
45+
if find_spec("torch_tensorrt._C") is not None:
46+
torch.ops.tensorrt.set_multi_device_safe_mode(mode)
47+
48+
logger.info(f"Set multi-device safe mode to {mode}")
49+
50+
# Return context manager in case the function is used in a `with` call
51+
return _MultiDeviceSafeModeContextManager(old_mode)

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ def run(self):
403403
"torch_tensorrt.fx.tracer",
404404
"torch_tensorrt.fx.tracer.acc_tracer",
405405
"torch_tensorrt.fx.tracer.dispatch_tracer",
406+
"torch_tensorrt.runtime",
406407
]
407408

408409
package_dir = {
@@ -430,6 +431,7 @@ def run(self):
430431
"torch_tensorrt.fx.tracer": "py/torch_tensorrt/fx/tracer",
431432
"torch_tensorrt.fx.tracer.acc_tracer": "py/torch_tensorrt/fx/tracer/acc_tracer",
432433
"torch_tensorrt.fx.tracer.dispatch_tracer": "py/torch_tensorrt/fx/tracer/dispatch_tracer",
434+
"torch_tensorrt.runtime": "py/torch_tensorrt/runtime",
433435
}
434436

435437
package_data = {}

0 commit comments

Comments
 (0)