Commit b6dbc8c

chore: mypy compliance with 3.11 syntax
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 8e92952 commit b6dbc8c

26 files changed: +326 -239 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ repos:
     rev: 'v1.4.1'
     hooks:
       - id: mypy
+        exclude: "^py/torch_tensorrt/fx"
   - repo: local
     hooks:
       - id: dont-commit-upstream

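Note: the new exclude entry keeps the mirrors-mypy hook away from the legacy py/torch_tensorrt/fx tree. A minimal sketch of how such a pattern filters candidate paths, assuming pre-commit's usual behavior of matching exclude as a Python regex against each file path (which is why the leading ^ anchor matters):

    import re

    # Hypothetical illustration; EXCLUDE mirrors the regex added to the mypy hook above.
    EXCLUDE = re.compile(r"^py/torch_tensorrt/fx")

    candidates = [
        "py/torch_tensorrt/fx/converters/activation.py",  # would be skipped by the hook
        "py/torch_tensorrt/_Device.py",                    # still type-checked
    ]
    for path in candidates:
        status = "excluded" if EXCLUDE.search(path) else "checked"
        print(f"{path}: {status}")
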
py/torch_tensorrt/_Device.py

Lines changed: 12 additions & 7 deletions
@@ -1,4 +1,11 @@
-from typing import Self, Optional, Any, Tuple
+from typing import TypeVar, Optional, Any, Tuple
+import sys
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
 import torch
 
 # from torch_tensorrt import _enums
@@ -25,7 +32,9 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """
 
-    device_type: Optional[trt.DeviceType] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[
+        trt.DeviceType
+    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
     gpu_id: int = -1  #: Device ID for target GPU
     dla_core: int = -1  #: Core ID for target DLA core
     allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
@@ -140,11 +149,7 @@ def _from_torch_device(cls, torch_dev: torch.device) -> Self:
 
     @classmethod
     def _current_device(cls) -> Self:
-        try:
-            dev = _C._get_current_device()
-        except RuntimeError:
-            logging.log(logging.Level.Error, "Cannot get current device")
-            return None
+        dev = _C._get_current_device()
         return cls(gpu_id=dev.gpu_id)
 
     @staticmethod

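The guarded import above is the standard way to use typing.Self on interpreters older than 3.11, where it is only available from typing_extensions. A minimal, self-contained sketch of the same pattern (a stand-in class for illustration, not the real torch_tensorrt Device):

    import sys

    if sys.version_info >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self  # backport for Python < 3.11


    class Device:
        # Simplified stand-in for illustration only.
        def __init__(self, gpu_id: int = -1) -> None:
            self.gpu_id = gpu_id

        @classmethod
        def _current_device(cls) -> Self:
            # Returning Self (rather than "Device") keeps subclass return types precise.
            return cls(gpu_id=0)


    print(Device._current_device().gpu_id)
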
py/torch_tensorrt/_Input.py

Lines changed: 29 additions & 16 deletions
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional, Union
+from typing import List, Dict, Any, Tuple, Optional, Union, Sequence
 
 import torch
 
@@ -27,13 +27,17 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1
 
-    shape_mode: Optional[_ShapeMode] = None  #: Is input statically or dynamically shaped
-    shape: Optional[Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = ( # type: ignore[name-defined]
+    shape_mode: Optional[
+        _ShapeMode
+    ] = None  #: Is input statically or dynamically shaped
+    shape: Optional[
+        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
+    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    dtype: _enums.dtype = (  # type: ignore[name-defined]
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = ( # type: ignore[name-defined]
+    format: _enums.TensorFormat = (  # type: ignore[name-defined]
        _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
 
@@ -187,7 +191,9 @@ def __str__(self) -> str:
                     str(self.tensor_domain[1]),
                 )
             else:
-                raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                raise RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                )
         else:
             raise RuntimeError("Unknown input shape mode")
 
@@ -203,7 +209,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
             return False
 
     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -231,7 +237,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
             )
 
     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -250,7 +256,7 @@ def is_trt_dtype(self) -> bool:
         return bool(self.dtype != _enums.dtype.long)
 
     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
@@ -270,7 +276,9 @@ def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defin
             )
 
     @staticmethod
-    def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple[float, float]:
+    def _parse_tensor_domain(
+        domain: Optional[Tuple[float, float]]
+    ) -> Tuple[float, float]:
         """
         Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
 
@@ -349,7 +357,7 @@ def from_tensor(
 
     @classmethod
     def from_tensors(
-        cls, ts: torch.Tensor, disable_memory_format_check: bool = False
+        cls, ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False
     ) -> List["Input"]:
         """
         Produce a list of Inputs which contain
@@ -369,7 +377,9 @@
             for t in ts
         ]
 
-    def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
+    def example_tensor(
+        self, optimization_profile_field: Optional[str] = None
+    ) -> torch.Tensor:
         """
         Get an example tensor of the shape specified by the Input object
 
@@ -388,7 +398,9 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Op
             if isinstance(self.shape, tuple):
                 return torch.rand(self.shape).to(dtype=self.torch_dtype)
             else:
-                RuntimeError(f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})")
+                RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
+                )
         else:
             if optimization_profile_field is not None:
                 try:
@@ -408,11 +420,12 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Op
                         dtype=self.torch_dtype
                     )
                 else:
-                    raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                    raise RuntimeError(
+                        f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                    )
 
             else:
                 raise ValueError(
                     "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
                 )
-        return None
-
+        raise

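The from_tensors change above is a typing fix rather than a behavior change: the method already iterates over its argument (note the `for t in ts` context line), and Sequence[torch.Tensor] now says so. A hedged usage sketch, assuming torch and torch_tensorrt are installed:

    import torch
    import torch_tensorrt

    # Any sequence of tensors is acceptable now that the annotation is Sequence[torch.Tensor].
    example_tensors = [torch.randn(1, 3, 224, 224), torch.randn(1, 10)]
    inputs = torch_tensorrt.Input.from_tensors(example_tensors)
    for spec in inputs:
        print(spec)
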
py/torch_tensorrt/_compile.py

Lines changed: 19 additions & 9 deletions
@@ -2,7 +2,9 @@
 
 import torch_tensorrt.ts
 
-from torch_tensorrt import logging, Input, dtype
+from torch_tensorrt import logging
+from torch_tensorrt._Input import Input
+from torch_tensorrt._enums import dtype
 import torch
 import torch.fx
 from enum import Enum
@@ -12,12 +14,18 @@
 from torch_tensorrt.fx.utils import LowerPrecision
 
 
-def _non_fx_input_interface(inputs: List[Input | torch.Tensor | InputTensorSpec]) -> TypeGuard[List[Input | torch.Tensor]]:
+def _non_fx_input_interface(
+    inputs: List[Input | torch.Tensor | InputTensorSpec],
+) -> TypeGuard[List[Input | torch.Tensor]]:
     return all([isinstance(i, torch.Tensor | Input) for i in inputs])
 
-def _fx_input_interface(inputs: List[Input | torch.Tensor | InputTensorSpec]) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
+
+def _fx_input_interface(
+    inputs: List[Input | torch.Tensor | InputTensorSpec],
+) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
     return all([isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs])
 
+
 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout"""
 
@@ -89,10 +97,12 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
 def compile(
     module: Any,
     ir: str = "default",
-    inputs: List[Union[Input, torch.Tensor, InputTensorSpec]] = [],
-    enabled_precisions: Set[Union[torch.dtype, dtype]] = set([torch.float]),
+    inputs: List[Input | torch.Tensor | InputTensorSpec] = [],
+    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
     **kwargs: Any,
-) -> Union[torch.nn.Module, torch.jit.ScriptModule, torch.fx.GraphModule, Callable[[Any], Any]]:
+) -> (
+    torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
+):
     """Compile a PyTorch module for NVIDIA GPUs using TensorRT
 
     Takes a existing PyTorch module and a set of settings to configure the compiler
@@ -168,7 +178,7 @@ def compile(
         )
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
-        return torch_tensorrt.dynamo.compile(
+        compiled_aten_module: torch.fx.GraphModule = torch_tensorrt.dynamo.compile(
             module,
             inputs=inputs,
             enabled_precisions=enabled_precisions,
@@ -198,8 +208,8 @@ def convert_method_to_trt_engine(
     module: Any,
     method_name: str,
     ir: str = "default",
-    inputs: List[Union[Input, torch.Tensor]] = [],
-    enabled_precisions: Set[Union[torch.dtype, dtype]] = set([torch.float]),
+    inputs: List[Input | torch.Tensor] = [],
+    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
     **kwargs: Any,
 ) -> bytes:
     """Convert a TorchScript module method to a serialized TensorRT engine

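The two _*_input_interface helpers are annotated with TypeGuard so mypy can narrow the mixed input list after the runtime check. A self-contained sketch of the same idea, using generic names rather than the torch_tensorrt API:

    from typing import List, TypeGuard  # TypeGuard is in typing from Python 3.10


    def _all_ints(values: List[object]) -> TypeGuard[List[int]]:
        # Returning True tells the type checker the caller may treat values as List[int].
        return all(isinstance(v, int) for v in values)


    def total(values: List[object]) -> int:
        if _all_ints(values):
            return sum(values)  # mypy sees List[int] here, so sum() type-checks
        raise TypeError("expected a list of ints")


    print(total([1, 2, 3]))
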
py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@
 MAX_AUX_STREAMS = None
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
-USE_PYTHON_RUNTIME = None
 TRUNCATE_LONG_AND_DOUBLE = False
+USE_PYTHON_RUNTIME = False

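Flipping USE_PYTHON_RUNTIME from None to False gives the constant an unambiguous bool value. One plausible motivation (an assumption, not stated in the commit) is that downstream settings annotate the corresponding field as bool, which a None default would not satisfy under strict type checking:

    from dataclasses import dataclass

    USE_PYTHON_RUNTIME = False  # a bool default, mirroring the change above


    @dataclass
    class Settings:  # hypothetical stand-in, not the real torch_tensorrt settings object
        use_python_runtime: bool = USE_PYTHON_RUNTIME  # a None default here would not fit the bool annotation


    print(Settings())
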
py/torch_tensorrt/dynamo/aten_tracer.py

Lines changed: 7 additions & 4 deletions
@@ -45,7 +45,6 @@ def __init__(
         specialize_int: bool = True,
         verbose: bool = True,
     ) -> None:
-
         self.capture_scalar_outputs = capture_scalar_outputs
         self.guard_nn_modules = guard_nn_modules
         self.dynamic_shapes = dynamic_shapes
@@ -128,7 +127,11 @@ def dynamo_trace(
 
 
 @req_torch_version("2.dev")
-def trace(model: Union[torch.nn.Module, torch.fx.GraphModule], inputs: Tuple[Any, ...], **kwargs: Any) -> torch.fx.GraphModule:
+def trace(
+    model: torch.nn.Module | torch.fx.GraphModule,
+    inputs: Tuple[Any, ...],
+    **kwargs: Any,
+) -> torch.fx.GraphModule:
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
     These passes should be general and functional purpose
@@ -149,11 +152,11 @@ def trace(model: Union[torch.nn.Module, torch.fx.GraphModule], inputs: Tuple[Any
     fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
     print(fx_module.graph)
     for passes in passes_list:
-        pr: PassResult = passes(fx_module) #type: ignore[assignment] #The type hints in fx are wrong
+        pr: PassResult = passes(fx_module)  # type: ignore[assignment] #The type hints in fx are wrong
         fx_module = pr.graph_module
 
     fx_module(*inputs)
 
     fx_module = run_const_fold(fx_module)
     print(fx_module.graph)
-    return fx_module #type: ignore[no-any-return]
+    return fx_module  # type: ignore[no-any-return]

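Beyond reflowing the trace signature, this hunk normalizes "#type: ignore[...]" comments to "# type: ignore[...]". The bracketed error code is mypy's scoped ignore: only that specific diagnostic is silenced on the line. A tiny illustration with unrelated names, for demonstration only:

    value: int = "not an int"  # type: ignore[assignment]  # silences only the assignment error on this line

    # A bare "# type: ignore" would hide every error on the line, which is why the
    # narrower bracketed form is generally preferred.
    print(value)
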
py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 9 deletions
@@ -27,22 +27,19 @@
 logger = logging.getLogger(__name__)
 
 
-@td.register_backend(name="torch_tensorrt") # type: ignore[misc]
+@td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
 
     compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
     return compiled_mod
 
-@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
+
+@td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
 
@@ -58,7 +55,7 @@ def aot_torch_tensorrt_aten_backend(
     return aot_module_simplified(
         gm,
         sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend), # type: ignore[no-untyped-call]
+        fw_compiler=make_boxed_compiler(custom_backend),  # type: ignore[no-untyped-call]
         decompositions=get_decompositions(),
     )
 

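Both functions above are registered as torch._dynamo backends, which makes them selectable by name through torch.compile. A hedged sketch of the same registration mechanism with a throwaway backend (the backend name and body below are made up for illustration and are not part of torch_tensorrt):

    from typing import Any, Sequence

    import torch
    import torch._dynamo as td


    @td.register_backend(name="noop_example")  # hypothetical backend name
    def noop_example_backend(
        gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
    ) -> torch.nn.Module:
        # Hand the captured FX graph back unchanged; a real backend would lower it to TensorRT here.
        return gm


    model = torch.nn.Linear(4, 4)
    compiled = torch.compile(model, backend="noop_example")
    print(compiled(torch.randn(2, 4)).shape)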