We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a12141c commit 48618f4Copy full SHA for 48618f4
py/torch_tensorrt/dynamo/lower.py
@@ -82,6 +82,12 @@ def compile(
82
raise ValueError("Invalid GPU ID provided for the CUDA device provided")
83
elif isinstance(device, torch.device):
84
device = device
85
+ elif isinstance(device, dict):
86
+ if "device_type" in device and device["device_type"] == trt.DeviceType.GPU:
87
+ if "gpu_id" in device:
88
+ device = torch.device(device["gpu_id"])
89
+ else:
90
+ device = torch.device("cuda:0")
91
else:
92
raise ValueError(
93
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
0 commit comments