
Commit 91c93f5

feat: Safety Mode for Runtime
- Add safety mode for Torch-TensorRT runtime
- Add C++ TorchBind bindings and relevant lambda functions to get and set necessary attributes
- Add runtime augmentations to support different modes
- Add testing for safe mode settings
1 parent 867dc7b commit 91c93f5
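In user code, the new toggles work roughly as follows. This is a minimal sketch using only the APIs added in this commit (safe mode is the default, per the `runtime.cpp` change below):

import torch
import torch_tensorrt

# Safe mode is on by default; turning it off skips the runtime safety checks
torch_tensorrt.enable_unsafe_inference_mode()
assert not torch.ops.tensorrt.get_safe_mode()

# ... latency-sensitive inference here ...

# Restore the default behavior
torch_tensorrt.enable_safe_inference_mode()
assert torch.ops.tensorrt.get_safe_mode()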

File tree: 7 files changed, +87 −4 lines

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     LOG_INFO("" << log_info);
   }
 
-  {
+  if (SAFE_MODE) {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
     if (compiled_engine->profile_execution) {
       device_profiler_guard =
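The checks wrapped by this `if (SAFE_MODE)` guard run on every engine execution, so one way to observe the toggle is to time the same compiled module in both modes. A rough sketch, reusing the sample model from the new test at the bottom of this commit; any measured difference is machine-dependent and illustrative only:

import time

import torch
import torch_tensorrt


class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.softmax((x + 2) * 7, dim=0)


inputs = [torch.randn(3, 5, 7).cuda()]
trt_mod = torch_tensorrt.compile(
    torch.fx.symbolic_trace(SampleModel()),
    "torch_compile",
    inputs,
    min_block_size=1,
)


def mean_latency(mod, iterations=100):
    # Synchronize so the timed window covers the enqueued GPU work
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        mod(*inputs)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iterations


torch_tensorrt.enable_safe_inference_mode()
safe = mean_latency(trt_mod)

torch_tensorrt.enable_unsafe_inference_mode()
unsafe = mean_latency(trt_mod)

print(f"safe: {safe * 1e6:.1f} us/iter, unsafe: {unsafe * 1e6:.1f} us/iter")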

core/runtime/register_jit_hooks.cpp

Lines changed: 2 additions & 0 deletions
@@ -114,6 +114,8 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("execute_engine", execute_engine);
   m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
   m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
+  m.def("get_safe_mode", []() -> bool { return SAFE_MODE; });
+  m.def("set_safe_mode", [](bool safe_mode) -> void { SAFE_MODE = safe_mode; });
 }
 
 } // namespace
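Because these accessors are registered on the `tensorrt` Torch library, they are callable straight from the op registry, independent of the Python wrappers added in `_compile.py`. A small sketch:

import torch
import torch_tensorrt  # importing the package loads the C++ runtime and registers the ops

# Backed by the lambdas registered above
torch.ops.tensorrt.set_safe_mode(False)
print(torch.ops.tensorrt.get_safe_mode())  # False
torch.ops.tensorrt.set_safe_mode(True)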

core/runtime/runtime.cpp

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
+bool SAFE_MODE = true;
+
 c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
   LOG_DEBUG("Target Device: " << target_device);
   auto device_options = find_compatible_devices(target_device);
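Since `SAFE_MODE` is initialized to `true`, a fresh process starts out in safe mode. A quick sanity check of that default:

import torch
import torch_tensorrt

# Reflects the `bool SAFE_MODE = true;` default above
assert torch.ops.tensorrt.get_safe_mode()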

core/runtime/runtime.h

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ namespace runtime {
 
 using EngineID = int64_t;
 const std::string ABI_VERSION = "4";
+extern bool SAFE_MODE;
 typedef enum {
   ABI_TARGET_IDX = 0,
   NAME_IDX,

py/torch_tensorrt/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -82,18 +82,23 @@ def _find_lib(name: str, paths: List[str]) -> str:
 
 import torch
 from torch_tensorrt._compile import *  # noqa: F403
+from torch_tensorrt._compile import (
+    enable_safe_inference_mode,
+    enable_unsafe_inference_mode,
+)
 from torch_tensorrt._Device import Device  # noqa: F401
 from torch_tensorrt._enums import *  # noqa: F403
 from torch_tensorrt._Input import Input  # noqa: F401
-from torch_tensorrt.logging import *
-from torch_tensorrt.ptq import *
 from torch_tensorrt._utils import *  # noqa: F403
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.logging import *
+from torch_tensorrt.ptq import *
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from torch_tensorrt import dynamo  # noqa: F401
     from torch_tensorrt.dynamo import backend  # noqa: F401
 
+    from torch_tensorrt import dynamo  # noqa: F401
+
 
 def _register_with_torch() -> None:
     trtorch_dir = os.path.dirname(__file__)

py/torch_tensorrt/_compile.py

Lines changed: 16 additions & 0 deletions
@@ -256,6 +256,22 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
     return boxed_fn
 
 
+def enable_unsafe_inference_mode():
+    """
+    Enables unsafe inference mode for Torch-TensorRT
+    """
+    torch.ops.tensorrt.set_safe_mode(False)
+    logger.info("Enabled unsafe inference mode")
+
+
+def enable_safe_inference_mode():
+    """
+    Enables safe inference mode for Torch-TensorRT
+    """
+    torch.ops.tensorrt.set_safe_mode(True)
+    logger.info("Enabled safe inference mode")
+
+
 def convert_method_to_trt_engine(
     module: Any,
     method_name: str = "forward",
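Because the two helpers simply flip a process-wide flag, a caller that wants to scope unsafe mode to one region of code could wrap them in a context manager that restores the prior setting. The `unsafe_inference_mode` helper below is hypothetical (not part of this commit), sketched on top of the functions above:

from contextlib import contextmanager

import torch
import torch_tensorrt


@contextmanager
def unsafe_inference_mode():
    """Hypothetical helper: enter unsafe mode, then restore the previous setting."""
    was_safe = torch.ops.tensorrt.get_safe_mode()
    torch_tensorrt.enable_unsafe_inference_mode()
    try:
        yield
    finally:
        if was_safe:
            torch_tensorrt.enable_safe_inference_mode()


# Usage: safety checks are skipped only inside the block
# with unsafe_inference_mode():
#     outputs = trt_mod(*inputs)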
New test file for the safe mode settings

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import torch
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+import torch_tensorrt
+
+
+class TestSafeMode(TestCase):
+    def test_safe_mode_enabled(self):
+        torch_tensorrt.enable_safe_inference_mode()
+        self.assertTrue(torch.ops.tensorrt.get_safe_mode())
+
+    def test_unsafe_mode_enabled(self):
+        torch_tensorrt.enable_unsafe_inference_mode()
+        self.assertFalse(torch.ops.tensorrt.get_safe_mode())
+
+    def test_unsafe_mode_enabled_inference(self):
+        torch_tensorrt.enable_unsafe_inference_mode()
+
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax((x + 2) * 7, dim=0)
+
+        inputs = [
+            torch.randn(
+                3,
+                5,
+                7,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(SampleModel())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg="Unsafe Mode TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()
