Skip to content

Commit 48618f4

Browse files
committed
chore: Fix device dict
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a12141c commit 48618f4

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

py/torch_tensorrt/dynamo/lower.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def compile(
8282
raise ValueError("Invalid GPU ID provided for the CUDA device provided")
8383
elif isinstance(device, torch.device):
8484
device = device
85+
elif isinstance(device, dict):
86+
if "device_type" in device and device["device_type"] == trt.DeviceType.GPU:
87+
if "gpu_id" in device:
88+
device = torch.device(device["gpu_id"])
89+
else:
90+
device = torch.device("cuda:0")
8591
else:
8692
raise ValueError(
8793
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"

0 commit comments

Comments
 (0)