Skip to content

Commit 338e542

Browse files
authored
fix: Address multi-GPU issue in engine deserialize (#2325)
1 parent 76de80d commit 338e542

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

core/runtime/execute_engine.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,8 @@ bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_devi
4343
return false;
4444
}
4545

46-
RTDevice select_rt_device(const RTDevice& engine_device) {
47-
auto new_target_device_opt = get_most_compatible_device(engine_device);
46+
RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device) {
47+
auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device);
4848

4949
// REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
5050
// TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right
@@ -89,7 +89,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
8989

9090
if (is_switch_required(curr_device, compiled_engine->device_info)) {
9191
// Scan through available CUDA devices and set the CUDA device context correctly
92-
RTDevice device = select_rt_device(compiled_engine->device_info);
92+
RTDevice device = select_rt_device(compiled_engine->device_info, curr_device);
9393
set_rt_device(device);
9494

9595
// Target device is new device

core/runtime/runtime.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@ namespace torch_tensorrt {
77
namespace core {
88
namespace runtime {
99

10-
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device) {
10+
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
1111
LOG_DEBUG("Target Device: " << target_device);
1212
auto device_options = find_compatible_devices(target_device);
13+
RTDevice current_device; // Device used as the first-priority ID match in the selection loop below
14+
if (curr_device.id == -1) { // Test the caller-supplied device, not the freshly constructed local; -1 is presumably the "unspecified" sentinel of a default RTDevice() - TODO confirm
15+
current_device = get_current_device();
16+
} else {
17+
current_device = curr_device;
18+
}
19+
1320
if (device_options.size() == 0) {
1421
return {};
1522
} else if (device_options.size() == 1) {
@@ -21,10 +28,20 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
2128
dev_list << "[" << std::endl;
2229
for (auto device : device_options) {
2330
dev_list << " " << device << ',' << std::endl;
24-
if (device.device_name == target_device.device_name && best_match.device_name != target_device.device_name) {
25-
best_match = device;
26-
} else if (device.device_name == target_device.device_name && best_match.device_name == target_device.device_name) {
27-
if (device.id == target_device.id && best_match.id != target_device.id) {
31+
if (device.device_name == target_device.device_name) {
32+
// First priority is selecting a candidate which agrees with the current device ID
33+
// If such a device is found, we can select it and break out of the loop
34+
if (device.id == current_device.id && best_match.id != current_device.id) {
35+
best_match = device;
36+
break;
37+
}
38+
// Second priority is selecting a candidate which agrees with the target device ID
39+
// At deserialization time, the current device and target device may not agree
40+
else if (device.id == target_device.id && best_match.id != target_device.id) {
41+
best_match = device;
42+
}
43+
// If no such GPU ID is found, select the first available candidate GPU
44+
else if (best_match.device_name != target_device.device_name) {
2845
best_match = device;
2946
}
3047
}

core/runtime/runtime.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,9 @@ typedef enum {
2626
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
2727
} SerializedInfoIndex;
2828

29-
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device);
29+
c10::optional<RTDevice> get_most_compatible_device(
30+
const RTDevice& target_device,
31+
const RTDevice& curr_device = RTDevice());
3032
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
3133

3234
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

0 commit comments

Comments
 (0)