feat: Safety Mode for Runtime #2512

Merged (4 commits) on Dec 8, 2023
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.cpp
@@ -52,6 +52,7 @@ TRTEngine::TRTEngine(
auto most_compatible_device = get_most_compatible_device(cuda_device);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
device_info = most_compatible_device.value();
multi_gpu_device_check();
set_rt_device(device_info);

rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
LOG_INFO("" << log_info);
}

{
if (MULTI_DEVICE_SAFE_MODE) {
std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
if (compiled_engine->profile_execution) {
device_profiler_guard =
4 changes: 4 additions & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -114,6 +114,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
}

} // namespace
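
The two ops registered above go through PyTorch's custom operator registry, so the C++-side flag can be read and toggled directly from Python once the extension library is loaded. A minimal usage sketch (this assumes a Torch-TensorRT build containing these registrations; importing torch_tensorrt loads the "tensorrt" op namespace):

import torch
import torch_tensorrt  # noqa: F401  # importing loads the "tensorrt" op library

# Read the current value of the C++-side MULTI_DEVICE_SAFE_MODE flag (False by default)
print(torch.ops.tensorrt.get_multi_device_safe_mode())

# Enable the per-inference device consistency check in the C++ runtime
torch.ops.tensorrt.set_multi_device_safe_mode(True)
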
17 changes: 15 additions & 2 deletions core/runtime/runtime.cpp
@@ -7,6 +7,8 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;

c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
LOG_DEBUG("Target Device: " << target_device);
auto device_options = find_compatible_devices(target_device);
@@ -31,13 +33,13 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
if (device.device_name == target_device.device_name) {
// First priority is selecting a candidate which agrees with the current device ID
// If such a device is found, we can select it and break out of the loop
if (device.id == current_device.id && best_match.id != current_device.id) {
if (device.id == current_device.id) {
Review comment (Collaborator): Why are we changing these?

Reply (gs-olive, Collaborator and author, Dec 7, 2023): During review, it was determined that these are not actually necessary conditions. Since the ID of each device is unique on a given machine, it should not be possible that best_match.id == current_device.id at this point in the code, so the check is not needed here.

best_match = device;
break;
}
// Second priority is selecting a candidate which agrees with the target device ID
// At deserialization time, the current device and target device may not agree
else if (device.id == target_device.id && best_match.id != target_device.id) {
else if (device.id == target_device.id) {
best_match = device;
}
// If no such GPU ID is found, select the first available candidate GPU
@@ -103,6 +105,17 @@ RTDevice get_current_device() {
return RTDevice(device_id, nvinfer1::DeviceType::kGPU);
}

void multi_gpu_device_check() {
// If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) {
LOG_WARNING(
"Detected this engine is being instantitated in a multi-GPU system with "
<< "multi-device safe mode disabled. For more on the implications of this "
<< "as well as workarounds, see the linked documentation "
<< "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)");
}
}

namespace {
static DeviceList cuda_device_list;
}
3 changes: 3 additions & 0 deletions core/runtime/runtime.h
@@ -16,6 +16,7 @@ namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "4";
extern bool MULTI_DEVICE_SAFE_MODE;
typedef enum {
ABI_TARGET_IDX = 0,
NAME_IDX,
@@ -33,6 +34,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

void multi_gpu_device_check();

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
DeviceMap device_list;
34 changes: 34 additions & 0 deletions docsrc/user_guide/runtime.rst
@@ -34,3 +34,37 @@ Plugin Library
In the case that you use Torch-TensorRT as a converter to a TensorRT engine and your engine uses plugins provided by Torch-TensorRT, Torch-TensorRT
ships the library ``libtorchtrt_plugins.so``, which contains the implementation of the TensorRT plugins used by Torch-TensorRT during
compilation. This library can be loaded via ``DL_OPEN`` or ``LD_PRELOAD`` similarly to other TensorRT plugin libraries.

Multi Device Safe Mode
----------------------

Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether
the runtime checks for device consistency prior to every inference call.

There is a non-negligible, fixed cost per inference call when multi-device safe mode is enabled, which is why
it is now disabled by default. It can be controlled via the following convenience function, which
doubles as a context manager.

.. code-block:: python

# Enables Multi Device Safe Mode
torch_tensorrt.runtime.set_multi_device_safe_mode(True)

# Disables Multi Device Safe Mode [Default Behavior]
torch_tensorrt.runtime.set_multi_device_safe_mode(False)

# Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
...

TensorRT requires that each engine be associated with the CUDA context of the active thread from which it is invoked.
Therefore, if the device changes in the active thread, as can happen when invoking
engines on multiple GPUs from the same Python process, safe mode causes Torch-TensorRT to display
an alert and switch GPUs accordingly. If safe mode is not enabled, the engine's device and the CUDA
context device could mismatch, which could cause the program to crash.
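
For illustration, a minimal sketch of this scenario (``trt_mod_0``, ``trt_mod_1``, and the input tensors are hypothetical; it assumes one engine was compiled for each of two GPUs):

.. code-block:: python

    import torch
    import torch_tensorrt

    torch_tensorrt.runtime.set_multi_device_safe_mode(True)

    # The cuda:0 engine executes while cuda:0 is the active device
    torch.cuda.set_device(0)
    out_0 = trt_mod_0(input_0.to("cuda:0"))

    # Invoking the cuda:1 engine from the same thread would otherwise mismatch the
    # active CUDA context; with safe mode enabled the runtime detects this, warns,
    # and switches to the engine's device
    out_1 = trt_mod_1(input_1.to("cuda:1"))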

One technique for managing multiple TRT engines on different GPUs without paying the per-call cost of
multi-device safe mode is to use Python threads. Each thread is responsible for all of the TRT engines
on a single GPU, and the default CUDA device of each thread corresponds to the GPU for which it is
responsible (it can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same
Python script without needing to switch CUDA contexts and incur the associated performance overhead, as sketched below.
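
A minimal sketch of this thread-per-GPU pattern (``trt_modules_by_gpu`` and ``inputs_by_gpu`` are hypothetical placeholders holding one compiled module and one list of input batches per GPU):

.. code-block:: python

    import threading

    import torch

    def run_on_gpu(device_id, trt_module, batches, results):
        # Pin this thread's default CUDA device to the GPU that owns the engine,
        # so no device switch (and no safe-mode overhead) is needed per call
        torch.cuda.set_device(device_id)
        results[device_id] = [
            trt_module(batch.to(f"cuda:{device_id}")) for batch in batches
        ]

    results = {}
    threads = [
        threading.Thread(target=run_on_gpu, args=(i, mod, batches, results))
        for i, (mod, batches) in enumerate(zip(trt_modules_by_gpu, inputs_by_gpu))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
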
8 changes: 5 additions & 3 deletions py/torch_tensorrt/__init__.py
@@ -85,15 +85,17 @@ def _find_lib(name: str, paths: List[str]) -> str:
from torch_tensorrt._Device import Device # noqa: F401
from torch_tensorrt._enums import * # noqa: F403
from torch_tensorrt._Input import Input # noqa: F401
from torch_tensorrt.logging import *
from torch_tensorrt.ptq import *
from torch_tensorrt._utils import * # noqa: F403
from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt.logging import *
from torch_tensorrt.ptq import *
from torch_tensorrt.runtime import * # noqa: F403

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from torch_tensorrt import dynamo # noqa: F401
from torch_tensorrt.dynamo import backend # noqa: F401

from torch_tensorrt import dynamo # noqa: F401


def _register_with_torch() -> None:
trtorch_dir = os.path.dirname(__file__)
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,15 +3,14 @@
import io
from typing import Sequence

import tensorrt as trt
import torch
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs

import tensorrt as trt


def convert_module(
module: torch.fx.GraphModule,
@@ -72,6 +71,8 @@ def convert_module(
engine=interpreter_result.engine,
input_names=list(interpreter_result.input_names),
output_names=list(interpreter_result.output_names),
target_device=settings.device,
profiling_enabled=settings.debug,
)

else:
@@ -42,10 +42,10 @@ def is_node_supported(
node_name = ConverterRegistry.qualified_name_or_str(node.target)

if (
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
node in CONVERTERS or node.op == "get_attr"
) and node_name not in self.torch_executed_ops:
# If node is a proper, supported computational node, store the operator
if not node.is_impure():
if not node.is_impure() and node.op != "get_attr":
if node_name not in self.supported_operators:
self.supported_operators[node_name] = 1
else:
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -150,10 +150,10 @@ def is_node_supported(
node_name = ConverterRegistry.qualified_name_or_str(node.target)

if (
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
node in CONVERTERS or node.op == "get_attr"
) and node_name not in self.torch_executed_ops:
# If node is a proper, supported computational node, store the operator
if not node.is_impure():
if not node.is_impure() and node.op != "get_attr":
if node_name not in self.supported_operators:
self.supported_operators[node_name] = 1
else:
66 changes: 61 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -1,13 +1,22 @@
from __future__ import annotations

import logging
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
from torch.nn import Module
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo.runtime.tools import (
_is_switch_required,
_select_rt_device,
multi_gpu_device_check,
)
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

import torch_tensorrt

logger = logging.getLogger(__name__)


@@ -23,13 +32,26 @@ def __init__(
engine: trt.ICudaEngine,
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
target_device: Device = Device._current_device(),
profiling_enabled: Optional[bool] = None,
):
super(PythonTorchTensorRTModule, self).__init__()
self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)

# Run multi-gpu device check to validate engine instantiation
multi_gpu_device_check()

self.engine = engine
self.input_names = input_names if input_names is not None else []
self.output_names = output_names if output_names is not None else []
self.initialized = False
self.target_device_id = target_device.gpu_id
self.target_device_properties = torch.cuda.get_device_properties(
self.target_device_id
)
self.profiling_enabled = (
profiling_enabled if profiling_enabled is not None else False
)
self._initialize()

def _initialize(self) -> None:
@@ -119,6 +141,9 @@ def _load_from_state_dict(
) -> None:
engine_bytes = state_dict[prefix + "engine"]

# Run multi-gpu device check to validate engine instantiation
multi_gpu_device_check()

logger = trt.Logger()
runtime = trt.Runtime(logger)
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
@@ -141,15 +166,43 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
if self.engine:
self.context = self.engine.create_execution_context()

def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:Forward"
):
) if self.profiling_enabled else nullcontext():
self._check_initialized()

# If in safe mode, check at each iteration whether a switch is required
if (
torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
):
curr_device_id = torch.cuda.current_device()
curr_device_properties = torch.cuda.get_device_properties(
curr_device_id
)
logger.debug(f"Current Device: cuda:{curr_device_id}")

# If a switch is required, move all inputs to new device and set as active device
if _is_switch_required(
curr_device_id,
self.target_device_id,
curr_device_properties,
self.target_device_properties,
):
device_id, _ = _select_rt_device(
curr_device_id,
self.target_device_id,
self.target_device_properties,
)
device = torch.device(device_id)
torch.cuda.set_device(device_id)

inputs = tuple([tensor.to(device) for tensor in inputs])
logger.warning(f"Moved all input Tensors to cuda:{device_id}")

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessInputs"
):
) if self.profiling_enabled else nullcontext():
assert len(inputs) == len(
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -188,7 +241,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessOutputs"
):
) if self.profiling_enabled else nullcontext():
# create output tensors
outputs: List[torch.Tensor] = []

@@ -215,7 +268,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:TensorRTRuntime"
):
) if self.profiling_enabled else nullcontext():
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)
@@ -235,6 +288,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
if not self.context.profiler:
self.context.profiler = trt.Profiler() if profiler is None else profiler

self.profiling_enabled = True

def disable_profiling(self) -> None:
"""
Disable TensorRT profiling.
@@ -244,6 +299,7 @@ def disable_profiling(self) -> None:
torch.cuda.synchronize()
del self.context
self.context = self.engine.create_execution_context()
self.profiling_enabled = False

def get_layer_info(self) -> str:
"""