Skip to content

Commit 36895d5

Browse files
committed
chore(//py/torch_tensorrt/ts): Make compile_spec conform to mypy
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 7f469ff commit 36895d5

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Any, Set
1+
from typing import List, Dict, Any, Set, Union, Optional
22
import torch
33
from torch_tensorrt import _C
44
import torch_tensorrt._C.ts as _ts_C
@@ -13,7 +13,7 @@
1313
import tensorrt as trt
1414

1515

16-
def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input:
16+
def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: # type: ignore[name-defined]
1717
clone = torch.classes.tensorrt._Input()
1818
clone._set_min(i.min)
1919
clone._set_opt(i.opt)
@@ -40,7 +40,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
4040
)
4141

4242

43-
def _parse_op_precision(precision: Any) -> _enums.dtype:
43+
def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-defined]
4444
if isinstance(precision, torch.dtype):
4545
if precision == torch.int8:
4646
return _enums.dtype.int8
@@ -64,7 +64,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype:
6464
)
6565

6666

67-
def _parse_enabled_precisions(precisions: Any) -> Set:
67+
def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ignore[name-defined]
6868
parsed_precisions = set()
6969
if any([isinstance(precisions, type) for type in [list, tuple, set]]):
7070
for p in precisions:
@@ -74,7 +74,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set:
7474
return parsed_precisions
7575

7676

77-
def _parse_device_type(device: Any) -> _enums.DeviceType:
77+
def _parse_device_type(device: Any) -> _enums.DeviceType: # type: ignore[name-defined]
7878
if isinstance(device, torch.device):
7979
if device.type == "cuda":
8080
return _C.DeviceType.gpu
@@ -159,7 +159,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:
159159
return info
160160

161161

162-
def _parse_input_signature(input_signature: Any, depth: int = 0):
162+
def _parse_input_signature(input_signature: Any, depth: int = 0) -> Any:
163163
if depth > 2:
164164
raise AssertionError(
165165
"Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]"
@@ -197,13 +197,16 @@ def _parse_input_signature(input_signature: Any, depth: int = 0):
197197
if i.shape_mode == Input._ShapeMode.STATIC:
198198
ts_i = TorchScriptInput(shape=i.shape, dtype=i.dtype, format=i.format)
199199
elif i.shape_mode == Input._ShapeMode.DYNAMIC:
200-
ts_i = TorchScriptInput(
201-
min_shape=i.shape["min_shape"],
202-
opt_shape=i.shape["opt_shape"],
203-
max_shape=i.shape["max_shape"],
204-
dtype=i.dtype,
205-
format=i.format,
206-
)
200+
if isinstance(i.shape, dict):
201+
ts_i = TorchScriptInput(
202+
min_shape=i.shape["min_shape"],
203+
opt_shape=i.shape["opt_shape"],
204+
max_shape=i.shape["max_shape"],
205+
dtype=i.dtype,
206+
format=i.format,
207+
)
208+
else:
209+
raise ValueError(f"Input set as dynamic, expected dictionary of shapes but found {i.shape}")
207210
else:
208211
raise ValueError(
209212
"Invalid shape mode detected for input while parsing the input_signature"
@@ -342,24 +345,24 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
342345

343346

344347
def TensorRTCompileSpec(
345-
inputs=[],
346-
input_signature=None,
347-
device=Device._current_device(),
348-
disable_tf32=False,
349-
sparse_weights=False,
350-
enabled_precisions=set(),
351-
refit=False,
352-
debug=False,
353-
capability=_enums.EngineCapability.default,
354-
num_avg_timing_iters=1,
355-
workspace_size=0,
356-
dla_sram_size=1048576,
357-
dla_local_dram_size=1073741824,
358-
dla_global_dram_size=536870912,
359-
truncate_long_and_double=False,
360-
calibrator=None,
361-
allow_shape_tensors=False,
362-
) -> torch.classes.tensorrt.CompileSpec:
348+
inputs: List[Union[torch.Tensor, Input]] = [],
349+
input_signature: Optional[Any] = None,
350+
device: Union[torch.Device, Device] = Device._current_device(),
351+
disable_tf32: bool = False,
352+
sparse_weights: bool = False,
353+
enabled_precisions: Set[Union[torch.dtype, _enums.dtype]] = set(), # type: ignore[name-defined]
354+
refit: bool = False,
355+
debug: bool = False,
356+
capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined]
357+
num_avg_timing_iters: int = 1,
358+
workspace_size: int = 0,
359+
dla_sram_size: int = 1048576,
360+
dla_local_dram_size: int = 1073741824,
361+
dla_global_dram_size: int = 536870912,
362+
truncate_long_and_double: bool = False,
363+
calibrator: object = None,
364+
allow_shape_tensors: bool = False,
365+
) -> torch.classes.tensorrt.CompileSpec: # type: ignore[name-defined]
363366
"""Utility to create a formatted spec dictionary for using the PyTorch TensorRT backend
364367
365368
Keyword Args:

0 commit comments

Comments
 (0)