Skip to content

Commit 7f469ff

Browse files
committed
chore(//py/torch_tensorrt): Making Input mypy compliant
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 3f05c77 commit 7f469ff

File tree

1 file changed

+59
-54
lines changed

1 file changed

+59
-54
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@ class _ShapeMode(Enum):
2727
STATIC = 0
2828
DYNAMIC = 1
2929

30-
shape_mode = None #: (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped
31-
shape: Union[Tuple, Dict] = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
32-
dtype = (
30+
shape_mode: Optional[_ShapeMode] = None #: Is input statically or dynamically shaped
31+
shape: Optional[Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
32+
dtype: _enums.dtype = ( # type: ignore[name-defined]
3333
_enums.dtype.unknown
3434
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
35-
_explicit_set_dtype = False
36-
format = (
35+
_explicit_set_dtype: bool = False
36+
format: _enums.TensorFormat = ( # type: ignore[name-defined]
3737
_enums.TensorFormat.contiguous
3838
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
3939

40-
DOMAIN_OFFSET = 2.0
41-
low_tensor_domain_incl = 0.0
42-
high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
43-
torch_dtype = torch.float32
40+
DOMAIN_OFFSET: float = 2.0
41+
low_tensor_domain_incl: float = 0.0
42+
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
43+
torch_dtype: torch.dtype = torch.float32
4444

4545
def __init__(self, *args: Any, **kwargs: Any) -> None:
4646
"""__init__ Method for torch_tensorrt.Input
@@ -165,15 +165,18 @@ def __str__(self) -> str:
165165
str(self.tensor_domain[1]),
166166
)
167167
elif self.shape_mode == Input._ShapeMode.DYNAMIC:
168-
return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format(
169-
self.shape["min_shape"],
170-
self.shape["opt_shape"],
171-
self.shape["max_shape"],
172-
str(self.dtype),
173-
str(self.format),
174-
str(self.tensor_domain[0]),
175-
str(self.tensor_domain[1]),
176-
)
168+
if isinstance(self.shape, dict):
169+
return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format(
170+
self.shape["min_shape"],
171+
self.shape["opt_shape"],
172+
self.shape["max_shape"],
173+
str(self.dtype),
174+
str(self.format),
175+
str(self.tensor_domain[0]),
176+
str(self.tensor_domain[1]),
177+
)
178+
else:
179+
raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
177180
else:
178181
raise RuntimeError("Unknown input shape mode")
179182

@@ -189,7 +192,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
189192
return False
190193

191194
@staticmethod
192-
def _parse_dtype(dtype: Any) -> _enums.dtype:
195+
def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
193196
if isinstance(dtype, torch.dtype):
194197
if dtype == torch.long:
195198
return _enums.dtype.long
@@ -217,7 +220,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
217220
)
218221

219222
@staticmethod
220-
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
223+
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
221224
if dtype == _enums.dtype.long:
222225
return torch.long
223226
elif dtype == _enums.dtype.int32:
@@ -233,10 +236,10 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
233236
return torch.float32
234237

235238
def is_trt_dtype(self) -> bool:
236-
return self.dtype != _enums.dtype.long
239+
return bool(self.dtype != _enums.dtype.long)
237240

238241
@staticmethod
239-
def _parse_format(format: Any) -> _enums.TensorFormat:
242+
def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
240243
if isinstance(format, torch.memory_format):
241244
if format == torch.contiguous_format:
242245
return _enums.TensorFormat.contiguous
@@ -256,7 +259,7 @@ def _parse_format(format: Any) -> _enums.TensorFormat:
256259
)
257260

258261
@staticmethod
259-
def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
262+
def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple[float, float]:
260263
"""
261264
Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
262265
@@ -355,7 +358,7 @@ def from_tensors(
355358
for t in ts
356359
]
357360

358-
def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
361+
def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
359362
"""
360363
Get an example tensor of the shape specified by the Input object
361364
@@ -365,38 +368,40 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
365368
Returns:
366369
A PyTorch Tensor
367370
"""
368-
if optimization_profile_field is not None:
369-
try:
370-
assert any(
371-
[
372-
optimization_profile_field == field_name
373-
for field_name in ["min_shape", "opt_shape", "max_shape"]
374-
]
375-
)
376-
except:
371+
if self.shape_mode == Input._ShapeMode.STATIC:
372+
if optimization_profile_field is not None:
377373
raise ValueError(
378-
"Invalid field name, expected one of min_shape, opt_shape, max_shape"
374+
"Specified a optimization profile field but the input is static"
379375
)
376+
else:
377+
if isinstance(self.shape, tuple):
378+
return torch.rand(self.shape).to(dtype=self.torch_dtype)
379+
else:
380+
raise RuntimeError(f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})")
381+
else:
382+
if optimization_profile_field is not None:
383+
try:
384+
assert any(
385+
[
386+
optimization_profile_field == field_name
387+
for field_name in ["min_shape", "opt_shape", "max_shape"]
388+
]
389+
)
390+
except:
391+
raise ValueError(
392+
"Invalid field name, expected one of min_shape, opt_shape, max_shape"
393+
)
380394

381-
if (
382-
optimization_profile_field is not None
383-
and self.shape_mode == Input._ShapeMode.STATIC
384-
):
385-
raise ValueError(
386-
"Specified a optimization profile field but the input is static"
387-
)
395+
if isinstance(self.shape, dict):
396+
return torch.rand(self.shape[optimization_profile_field]).to(
397+
dtype=self.torch_dtype
398+
)
399+
else:
400+
raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
401+
402+
else:
403+
raise ValueError(
404+
"Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
405+
)
388406

389-
if (
390-
optimization_profile_field is None
391-
and self.shape_mode == Input._ShapeMode.DYNAMIC
392-
):
393-
raise ValueError(
394-
"Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
395-
)
396407

397-
if self.shape_mode == Input._ShapeMode.STATIC:
398-
return torch.rand(self.shape).to(dtype=self.torch_dtype)
399-
else:
400-
return torch.rand(self.shape[optimization_profile_field]).to(
401-
dtype=self.torch_dtype
402-
)

0 commit comments

Comments
 (0)