Commit 6660b3b

chore: mypy compliance with 3.11 syntax
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent e160a30 commit 6660b3b

26 files changed: +333, -244 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ repos:
     rev: 'v1.4.1'
     hooks:
       - id: mypy
+        exclude: "^py/torch_tensorrt/fx"
   - repo: local
     hooks:
       - id: dont-commit-upstream

py/torch_tensorrt/_Device.py

Lines changed: 12 additions & 7 deletions
@@ -1,4 +1,11 @@
-from typing import Self, Optional, Any, Tuple
+from typing import TypeVar, Optional, Any, Tuple
+import sys
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
 import torch
 
 # from torch_tensorrt import _enums
@@ -25,7 +32,9 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """
 
-    device_type: Optional[trt.DeviceType] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[
+        trt.DeviceType
+    ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
     gpu_id: int = -1 #: Device ID for target GPU
     dla_core: int = -1 #: Core ID for target DLA core
     allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
@@ -140,11 +149,7 @@ def _from_torch_device(cls, torch_dev: torch.device) -> Self:
 
     @classmethod
     def _current_device(cls) -> Self:
-        try:
-            dev = _C._get_current_device()
-        except RuntimeError:
-            logging.log(logging.Level.Error, "Cannot get current device")
-            return None
+        dev = _C._get_current_device()
         return cls(gpu_id=dev.gpu_id)
 
     @staticmethod
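
The _Device.py hunk gates the Self import on the interpreter version: typing.Self only exists from Python 3.11, so older interpreters fall back to the typing_extensions backport. A minimal sketch of the same pattern (the Handle class and its method are illustrative, not code from this repo):

import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self


class Handle:
    def __init__(self, gpu_id: int) -> None:
        self.gpu_id = gpu_id

    @classmethod
    def current(cls) -> Self:
        # Self keeps the return type bound to the subclass under mypy,
        # which a plain "Handle" annotation would not.
        return cls(gpu_id=0)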

py/torch_tensorrt/_Input.py

Lines changed: 29 additions & 16 deletions
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional, Union
+from typing import List, Dict, Any, Tuple, Optional, Union, Sequence
 
 import torch
 
@@ -27,13 +27,17 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1
 
-    shape_mode: Optional[_ShapeMode] = None #: Is input statically or dynamically shaped
-    shape: Optional[Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = ( # type: ignore[name-defined]
+    shape_mode: Optional[
+        _ShapeMode
+    ] = None #: Is input statically or dynamically shaped
+    shape: Optional[
+        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
+    ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    dtype: _enums.dtype = (  # type: ignore[name-defined]
         _enums.dtype.unknown
     ) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = ( # type: ignore[name-defined]
+    format: _enums.TensorFormat = (  # type: ignore[name-defined]
         _enums.TensorFormat.contiguous
     ) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
 
@@ -176,7 +180,9 @@ def __str__(self) -> str:
                     str(self.tensor_domain[1]),
                 )
             else:
-                raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                raise RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                )
         else:
             raise RuntimeError("Unknown input shape mode")
 
@@ -192,7 +198,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
            return False
 
    @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
        if isinstance(dtype, torch.dtype):
            if dtype == torch.long:
                return _enums.dtype.long
@@ -220,7 +226,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
            )
 
    @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
        if dtype == _enums.dtype.long:
            return torch.long
        elif dtype == _enums.dtype.int32:
@@ -239,7 +245,7 @@ def is_trt_dtype(self) -> bool:
        return bool(self.dtype != _enums.dtype.long)
 
    @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
        if isinstance(format, torch.memory_format):
            if format == torch.contiguous_format:
                return _enums.TensorFormat.contiguous
@@ -259,7 +265,9 @@ def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defin
            )
 
    @staticmethod
-    def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple[float, float]:
+    def _parse_tensor_domain(
+        domain: Optional[Tuple[float, float]]
+    ) -> Tuple[float, float]:
        """
        Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
 
@@ -338,7 +346,7 @@ def from_tensor(
 
    @classmethod
    def from_tensors(
-        cls, ts: torch.Tensor, disable_memory_format_check: bool = False
+        cls, ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False
    ) -> List["Input"]:
        """
        Produce a list of Inputs which contain
@@ -358,7 +366,9 @@ def from_tensors(
            for t in ts
        ]
 
-    def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
+    def example_tensor(
+        self, optimization_profile_field: Optional[str] = None
+    ) -> torch.Tensor:
        """
        Get an example tensor of the shape specified by the Input object
 
@@ -377,7 +387,9 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Op
            if isinstance(self.shape, tuple):
                return torch.rand(self.shape).to(dtype=self.torch_dtype)
            else:
-                RuntimeError(f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})")
+                RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
+                )
        else:
            if optimization_profile_field is not None:
                try:
@@ -397,11 +409,12 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Op
                        dtype=self.torch_dtype
                    )
                else:
-                    raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                    raise RuntimeError(
+                        f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                    )
 
            else:
                raise ValueError(
                    "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
                )
-        return None
-
+        raise
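
The _Input.py changes also replace Union[...] spellings with PEP 604 "X | Y" unions. Both forms mean the same thing to mypy; the "|" form only needs Python 3.10+ when the annotation (or an isinstance check against it) is evaluated at runtime. A small illustrative comparison, not taken from the file:

from typing import Dict, Optional, Tuple, Union

# Equivalent annotations in the style of the Input.shape attribute above:
ShapeOld = Optional[Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]]
ShapeNew = Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]]  # needs Python >= 3.10 at runtime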

py/torch_tensorrt/_compile.py

Lines changed: 19 additions & 9 deletions
@@ -2,7 +2,9 @@
 
 import torch_tensorrt.ts
 
-from torch_tensorrt import logging, Input, dtype
+from torch_tensorrt import logging
+from torch_tensorrt._Input import Input
+from torch_tensorrt._enums import dtype
 import torch
 import torch.fx
 from enum import Enum
@@ -12,12 +14,18 @@
 from torch_tensorrt.fx.utils import LowerPrecision
 
 
-def _non_fx_input_interface(inputs: List[Input | torch.Tensor | InputTensorSpec]) -> TypeGuard[List[Input | torch.Tensor]]:
+def _non_fx_input_interface(
+    inputs: List[Input | torch.Tensor | InputTensorSpec],
+) -> TypeGuard[List[Input | torch.Tensor]]:
     return all([isinstance(i, torch.Tensor | Input) for i in inputs])
 
-def _fx_input_interface(inputs: List[Input | torch.Tensor | InputTensorSpec]) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
+
+def _fx_input_interface(
+    inputs: List[Input | torch.Tensor | InputTensorSpec],
+) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
     return all([isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs])
 
+
 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout"""
 
@@ -89,10 +97,12 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
 def compile(
     module: Any,
     ir: str = "default",
-    inputs: List[Union[Input, torch.Tensor, InputTensorSpec]] = [],
-    enabled_precisions: Set[Union[torch.dtype, dtype]] = set([torch.float]),
+    inputs: List[Input | torch.Tensor | InputTensorSpec] = [],
+    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
     **kwargs: Any,
-) -> Union[torch.nn.Module, torch.jit.ScriptModule, torch.fx.GraphModule, Callable[[Any], Any]]:
+) -> (
+    torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
+):
     """Compile a PyTorch module for NVIDIA GPUs using TensorRT
 
     Takes a existing PyTorch module and a set of settings to configure the compiler
@@ -168,7 +178,7 @@ def compile(
         )
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
-        return torch_tensorrt.dynamo.compile(
+        compiled_aten_module: torch.fx.GraphModule = torch_tensorrt.dynamo.compile(
             module,
             inputs=inputs,
             enabled_precisions=enabled_precisions,
@@ -198,8 +208,8 @@ def convert_method_to_trt_engine(
     module: Any,
     method_name: str,
     ir: str = "default",
-    inputs: List[Union[Input, torch.Tensor]] = [],
-    enabled_precisions: Set[Union[torch.dtype, dtype]] = set([torch.float]),
+    inputs: List[Input | torch.Tensor] = [],
+    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
     **kwargs: Any,
 ) -> bytes:
     """Convert a TorchScript module method to a serialized TensorRT engine

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@
 MAX_AUX_STREAMS = None
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
-USE_PYTHON_RUNTIME = None
+USE_PYTHON_RUNTIME = False

py/torch_tensorrt/dynamo/aten_tracer.py

Lines changed: 7 additions & 4 deletions
@@ -45,7 +45,6 @@ def __init__(
         specialize_int: bool = True,
         verbose: bool = True,
     ) -> None:
-
         self.capture_scalar_outputs = capture_scalar_outputs
         self.guard_nn_modules = guard_nn_modules
         self.dynamic_shapes = dynamic_shapes
@@ -128,7 +127,11 @@ def dynamo_trace(
 
 
 @req_torch_version("2.dev")
-def trace(model: Union[torch.nn.Module, torch.fx.GraphModule], inputs: Tuple[Any, ...], **kwargs: Any) -> torch.fx.GraphModule:
+def trace(
+    model: torch.nn.Module | torch.fx.GraphModule,
+    inputs: Tuple[Any, ...],
+    **kwargs: Any,
+) -> torch.fx.GraphModule:
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
     These passes should be general and functional purpose
@@ -149,11 +152,11 @@ def trace(model: Union[torch.nn.Module, torch.fx.GraphModule], inputs: Tuple[Any
     fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
     print(fx_module.graph)
     for passes in passes_list:
-        pr: PassResult = passes(fx_module) #type: ignore[assignment] #The type hints in fx are wrong
+        pr: PassResult = passes(fx_module)  # type: ignore[assignment] #The type hints in fx are wrong
         fx_module = pr.graph_module
 
     fx_module(*inputs)
 
     fx_module = run_const_fold(fx_module)
     print(fx_module.graph)
-    return fx_module #type: ignore[no-any-return]
+    return fx_module  # type: ignore[no-any-return]

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 9 deletions
@@ -24,22 +24,19 @@
 logger = logging.getLogger(__name__)
 
 
-@td.register_backend(name="torch_tensorrt") # type: ignore[misc]
+@td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
 
     compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
     return compiled_mod
 
-@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
+
+@td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
 
@@ -55,7 +52,7 @@ def aot_torch_tensorrt_aten_backend(
     return aot_module_simplified(
         gm,
         sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend), # type: ignore[no-untyped-call]
+        fw_compiler=make_boxed_compiler(custom_backend),  # type: ignore[no-untyped-call]
         decompositions=get_decompositions(),
     )
 
