Skip to content

Commit 3f05c77

Browse files
committed
chore(//py/torch_tensorrt): Make Device conform to mypy
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent ac0126f commit 3f05c77

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

py/torch_tensorrt/_Device.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Self
1+
from typing import Self, Optional, Any, Tuple
22
import torch
33

44
# from torch_tensorrt import _enums
@@ -25,12 +25,12 @@ class Device(object):
2525
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
2626
"""
2727

28-
device_type = None #: (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
29-
gpu_id = -1 #: (int) Device ID for target GPU
30-
dla_core = -1 #: (int) Core ID for target DLA core
31-
allow_gpu_fallback = False #: (bool) Whether falling back to GPU if DLA cannot support an op should be allowed
28+
device_type: Optional[trt.DeviceType] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
29+
gpu_id: int = -1 #: Device ID for target GPU
30+
dla_core: int = -1 #: Core ID for target DLA core
31+
allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
3232

33-
def __init__(self, *args, **kwargs):
33+
def __init__(self, *args: Any, **kwargs: Any):
3434
"""__init__ Method for torch_tensorrt.Device
3535
3636
Device accepts one of a few construction patterns
@@ -128,10 +128,11 @@ def _to_internal(self) -> _C.Device:
128128

129129
def _to_serialized_rt_device(self) -> str:
130130
internal_dev = self._to_internal()
131-
return internal_dev._to_serialized_rt_device()
131+
serialized_rt_device: str = internal_dev._to_serialized_rt_device()
132+
return serialized_rt_device
132133

133134
@classmethod
134-
def _from_torch_device(cls, torch_dev: torch.device):
135+
def _from_torch_device(cls, torch_dev: torch.device) -> Self:
135136
if torch_dev.type != "cuda":
136137
raise ValueError('Torch Device specs must have type "cuda"')
137138
gpu_id = torch_dev.index
@@ -147,10 +148,12 @@ def _current_device(cls) -> Self:
147148
return cls(gpu_id=dev.gpu_id)
148149

149150
@staticmethod
150-
def _parse_device_str(s):
151+
def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]:
151152
s = s.lower()
152153
spec = s.split(":")
153154
if spec[0] == "gpu" or spec[0] == "cuda":
154155
return (trt.DeviceType.GPU, int(spec[1]))
155156
elif spec[0] == "dla":
156157
return (trt.DeviceType.DLA, int(spec[1]))
158+
else:
159+
raise ValueError(f"Unknown device type {spec[0]}")

0 commit comments

Comments
 (0)