1
- from typing import Self
1
+ from typing import Self , Optional , Any , Tuple
2
2
import torch
3
3
4
4
# from torch_tensorrt import _enums
@@ -25,12 +25,12 @@ class Device(object):
25
25
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
26
26
"""
27
27
28
- device_type = None #: (torch_tensorrt.DeviceType) : Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
29
- gpu_id = - 1 #: (int) Device ID for target GPU
30
- dla_core = - 1 #: (int) Core ID for target DLA core
31
- allow_gpu_fallback = False #: (bool) Whether falling back to GPU if DLA cannot support an op should be allowed
28
+ device_type : Optional [ trt . DeviceType ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
29
+ gpu_id : int = - 1 #: Device ID for target GPU
30
+ dla_core : int = - 1 #: Core ID for target DLA core
31
+ allow_gpu_fallback : bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
32
32
33
- def __init__ (self , * args , ** kwargs ):
33
+ def __init__ (self , * args : Any , ** kwargs : Any ):
34
34
"""__init__ Method for torch_tensorrt.Device
35
35
36
36
Device accepts one of a few construction patterns
@@ -128,10 +128,11 @@ def _to_internal(self) -> _C.Device:
128
128
129
129
def _to_serialized_rt_device (self ) -> str :
130
130
internal_dev = self ._to_internal ()
131
- return internal_dev ._to_serialized_rt_device ()
131
+ serialized_rt_device : str = internal_dev ._to_serialized_rt_device ()
132
+ return serialized_rt_device
132
133
133
134
@classmethod
134
- def _from_torch_device (cls , torch_dev : torch .device ):
135
+ def _from_torch_device (cls , torch_dev : torch .device ) -> Self :
135
136
if torch_dev .type != "cuda" :
136
137
raise ValueError ('Torch Device specs must have type "cuda"' )
137
138
gpu_id = torch_dev .index
@@ -147,10 +148,12 @@ def _current_device(cls) -> Self:
147
148
return cls (gpu_id = dev .gpu_id )
148
149
149
150
@staticmethod
150
- def _parse_device_str (s ) :
151
+ def _parse_device_str (s : str ) -> Tuple [ trt . DeviceType , int ] :
151
152
s = s .lower ()
152
153
spec = s .split (":" )
153
154
if spec [0 ] == "gpu" or spec [0 ] == "cuda" :
154
155
return (trt .DeviceType .GPU , int (spec [1 ]))
155
156
elif spec [0 ] == "dla" :
156
157
return (trt .DeviceType .DLA , int (spec [1 ]))
158
+ else :
159
+ raise ValueError (f"Unknown device type { spec [0 ]} " )
0 commit comments