Skip to content

Commit 1e1c7a8

Browse files
authored
feat: Add support for multi-device safe mode in C++ (#2824)
1 parent 9c18138 commit 1e1c7a8

File tree

5 files changed

+58
-0
lines changed

5 files changed

+58
-0
lines changed

core/runtime/runtime.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ void multi_gpu_device_check() {
121121
}
122122
}
123123

124+
bool get_multi_device_safe_mode() {
125+
return MULTI_DEVICE_SAFE_MODE;
126+
}
127+
128+
void set_multi_device_safe_mode(bool multi_device_safe_mode) {
129+
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
130+
}
131+
124132
namespace {
125133
static DeviceList cuda_device_list;
126134
}

core/runtime/runtime.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
3838

3939
void multi_gpu_device_check();
4040

41+
bool get_multi_device_safe_mode();
42+
43+
void set_multi_device_safe_mode(bool multi_device_safe_mode);
44+
4145
class DeviceList {
4246
using DeviceMap = std::unordered_map<int, RTDevice>;
4347
DeviceMap device_list;

tests/core/runtime/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
load("//tests/core/runtime:runtime_test.bzl", "runtime_test")
2+
13
package(default_visibility = ["//visibility:public"])
24

35
config_setting(
@@ -6,3 +8,14 @@ config_setting(
68
"define": "abi=pre_cxx11_abi",
79
},
810
)
11+
12+
runtime_test(
13+
name = "test_multi_device_safe_mode",
14+
)
15+
16+
test_suite(
17+
name = "runtime_tests",
18+
tests = [
19+
":test_multi_device_safe_mode",
20+
],
21+
)

tests/core/runtime/runtime_test.bzl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""
2+
runtime test macros
3+
"""
4+
5+
load("@rules_cc//cc:defs.bzl", "cc_test")
6+
7+
def runtime_test(name, visibility = None):
8+
"""Macro to define a runtime test
9+
10+
Args:
11+
name: Name of test file
12+
visibility: Visibility of the test target
13+
"""
14+
cc_test(
15+
name = name,
16+
srcs = [name + ".cpp"],
17+
visibility = visibility,
18+
deps = [
19+
"//tests/util",
20+
"@googletest//:gtest_main",
21+
] + select({
22+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
23+
"//conditions:default": ["@libtorch//:libtorch"],
24+
}),
25+
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include "core/runtime/runtime.h"
2+
#include "gtest/gtest.h"
3+
4+
TEST(Runtime, MultiDeviceSafeMode) {
5+
ASSERT_TRUE(!torch_tensorrt::core::runtime::get_multi_device_safe_mode());
6+
torch_tensorrt::core::runtime::set_multi_device_safe_mode(true);
7+
ASSERT_TRUE(torch_tensorrt::core::runtime::get_multi_device_safe_mode());
8+
}

0 commit comments

Comments
 (0)