@@ -7,9 +7,16 @@ namespace torch_tensorrt {
7
7
namespace core {
8
8
namespace runtime {
9
9
10
- c10::optional<RTDevice> get_most_compatible_device (const RTDevice& target_device) {
10
+ c10::optional<RTDevice> get_most_compatible_device (const RTDevice& target_device, const RTDevice& curr_device ) {
11
11
LOG_DEBUG (" Target Device: " << target_device);
12
12
auto device_options = find_compatible_devices (target_device);
13
+ RTDevice current_device;
14
+ if (current_device.id == -1 ) {
15
+ current_device = get_current_device ();
16
+ } else {
17
+ current_device = curr_device;
18
+ }
19
+
13
20
if (device_options.size () == 0 ) {
14
21
return {};
15
22
} else if (device_options.size () == 1 ) {
@@ -21,10 +28,20 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
21
28
dev_list << " [" << std::endl;
22
29
for (auto device : device_options) {
23
30
dev_list << " " << device << ' ,' << std::endl;
24
- if (device.device_name == target_device.device_name && best_match.device_name != target_device.device_name ) {
25
- best_match = device;
26
- } else if (device.device_name == target_device.device_name && best_match.device_name == target_device.device_name ) {
27
- if (device.id == target_device.id && best_match.id != target_device.id ) {
31
+ if (device.device_name == target_device.device_name ) {
32
+ // First priority is selecting a candidate which agrees with the current device ID
33
+ // If such a device is found, we can select it and break out of the loop
34
+ if (device.id == current_device.id && best_match.id != current_device.id ) {
35
+ best_match = device;
36
+ break ;
37
+ }
38
+ // Second priority is selecting a candidate which agrees with the target device ID
39
+ // At deserialization time, the current device and target device may not agree
40
+ else if (device.id == target_device.id && best_match.id != target_device.id ) {
41
+ best_match = device;
42
+ }
43
+ // If no such GPU ID is found, select the first available candidate GPU
44
+ else if (best_match.device_name != target_device.device_name ) {
28
45
best_match = device;
29
46
}
30
47
}
0 commit comments