
Commit a39d254

fix: Linter + config fix (#2636)
1 parent: 3390e24

23 files changed, +163 -152 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ repos:
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.1.1
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
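
The only functional change here is the Black version bump; most of the remaining changes in this commit are simply the newer formatter's output (plus a few import-order fixes), with no behavioral change intended. To check locally how the pinned version rewrites a snippet, Black's Python API can be used directly. A minimal sketch, assuming black==24.1.1 is installed; the sample string below is made up:

import black  # pip install black==24.1.1

# Hypothetical one-liner; format_str() returns the source as Black would rewrite it.
sample = "x:int=1  # a trailing comment\n"
print(black.format_str(sample, mode=black.Mode()))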

examples/int8/training/vgg16/vgg16.py

Lines changed: 3 additions & 1 deletion
@@ -3,10 +3,12 @@
 - [Very Deep Convolutional Networks for Large-Scale Image Recognition](
 https://arxiv.org/abs/1409.1556) (ICLR 2015)
 """
+
+from functools import reduce
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import reduce


 class VGG(nn.Module):
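
The import move simply restores the conventional grouping the import checks expect: standard-library modules first, then third-party packages, separated by a blank line. A minimal sketch of the convention, using the names from the file above:

# Standard-library imports come first...
from functools import reduce

# ...followed by third-party imports after a blank line.
import torch
import torch.nn as nn
import torch.nn.functional as F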

py/torch_tensorrt/_Device.py

Lines changed: 6 additions & 4 deletions
@@ -32,12 +32,14 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
     gpu_id: int = -1  #: Device ID for target GPU
     dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

     def __init__(self, *args: Any, **kwargs: Any):
         """__init__ Method for torch_tensorrt.Device

py/torch_tensorrt/_Input.py

Lines changed: 6 additions & 6 deletions
@@ -28,12 +28,12 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
     dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 5 deletions
@@ -5,6 +5,7 @@
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

 import torch
+import torch_tensorrt
 from torch.export import ExportedProgram
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -66,8 +67,6 @@
     to_torch_tensorrt_device,
 )

-import torch_tensorrt
-
 logger = logging.getLogger(__name__)


@@ -217,9 +216,9 @@ def compile(
         "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 6 additions & 6 deletions
@@ -28,9 +28,9 @@

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


 class UnsupportedOperatorException(RuntimeError):
@@ -92,9 +92,9 @@ def __init__(
         self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
         self.compilation_settings = compilation_settings

         # Data types for TRT Module output Tensors

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 2 additions & 4 deletions
@@ -324,13 +324,11 @@ def get_trt_tensor(


 @overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


 @overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


 def get_positive_dim(
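
Overload stubs whose body is only `...` are now collapsed onto the signature line. A self-contained sketch of the pattern, using a hypothetical helper (not the library's get_positive_dim):

from typing import Sequence, Tuple, Union, overload


@overload
def positive_dim(dim: int, dim_size: int) -> int: ...


@overload
def positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
    # Shared implementation: map possibly-negative dims into [0, dim_size).
    if isinstance(dim, int):
        return dim % dim_size
    return tuple(d % dim_size for d in dim)


print(positive_dim(-1, 4))       # 3
print(positive_dim((-1, 0), 4))  # (3, 0)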

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 6 additions & 6 deletions
@@ -7,9 +7,9 @@

 aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
 torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten._adaptive_avg_pool2d_backward,
     aten.addcdiv,
@@ -180,9 +180,9 @@
 }


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
 TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

Lines changed: 4 additions & 6 deletions
@@ -22,12 +22,10 @@ def lower_linear(
     return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for linear"""

     # Original graph
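
The formatter also stops wrapping long return annotations in an extra pair of parentheses; the Tuple[...] subscript splits instead, so the arrow stays on the def line, as in the three lowering passes touched here. A sketch with a hypothetical helper that mirrors the shape of these replacement functions:

from typing import Callable, Tuple

import torch


def replacement_pair() -> Tuple[
    Callable[[torch.Tensor], torch.Tensor],
    Callable[[torch.Tensor], torch.Tensor],
]:
    # Hypothetical stand-in returning an (original, replacement) callable pair.
    def original(x: torch.Tensor) -> torch.Tensor:
        return x.view(-1)

    def replacement(x: torch.Tensor) -> torch.Tensor:
        return x.reshape(-1)

    return original, replacement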

py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

Lines changed: 4 additions & 6 deletions
@@ -60,12 +60,10 @@ def lower_scaled_dot_product_attention(
     return gm


-def scaled_dot_product_attention_replacement() -> (
-    Tuple[
-        Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for efficient attention"""

     # Efficient Attention original graph

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 4 additions & 6 deletions
@@ -22,12 +22,10 @@ def view_to_reshape(
     return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
     """Constructs the original and replacement functions for view"""

     # Original graph

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 37 additions & 20 deletions
@@ -6,6 +6,7 @@

 import tensorrt as trt
 import torch
+import torch_tensorrt
 from torch.nn import Module
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo.runtime.tools import (
@@ -15,8 +16,6 @@
 )
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-import torch_tensorrt
-
 logger = logging.getLogger(__name__)


@@ -101,9 +100,11 @@ def _initialize(self) -> None:
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.output_binding_indices_in_order
         ]
         self.hidden_output_dtypes = [
@@ -113,9 +114,11 @@ def _initialize(self) -> None:
             for idx in self.hidden_output_binding_indices_in_order
         ]
         self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.hidden_output_binding_indices_in_order
         ]

@@ -167,9 +170,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         self.context = self.engine.create_execution_context()

     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             self._check_initialized()

             # If in safe mode, check at each iteration for for whether a switch is required
@@ -200,9 +205,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 inputs = tuple([tensor.to(device) for tensor in inputs])
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -239,9 +248,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []

@@ -266,9 +279,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
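
All of the profiling changes above share one shape: the conditionally selected context manager is now parenthesized inside the with statement rather than hanging the if/else off the call. A runnable sketch of the pattern with a hypothetical flag, using contextlib.nullcontext as the no-op branch:

from contextlib import nullcontext

import torch

profiling_enabled = False  # hypothetical flag standing in for self.profiling_enabled

with (
    torch.autograd.profiler.record_function("Example:Forward")
    if profiling_enabled
    else nullcontext()
):
    # The body runs either way; it only shows up as an annotated range in
    # profiler traces when profiling_enabled is True.
    y = torch.ones(4) * 2

print(y)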

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 16 additions & 22 deletions
@@ -3,24 +3,22 @@
 import math
 import operator
 import warnings
-from typing import cast, Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union, cast

 import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+from torch.fx.immutable_collections import immutable_list
+from torch.fx.node import Argument, Target
 from torch_tensorrt.fx.converters import acc_ops_converters
+from torch_tensorrt.fx.converters.impl import activation, convolution

 from ..converter_registry import tensorrt_converter
-
 from ..types import *  # noqa: F403
-from torch.fx.immutable_collections import immutable_list
-from torch.fx.node import Argument, Target
-
 from .converter_utils import *  # noqa: F403
-import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-from torch_tensorrt.fx.converters.impl import activation, convolution

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -317,21 +315,17 @@ def aten_ops_max_poolnd(
     kwargs_new = {
         "input": args[0],
         "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
         "ceil_mode": args[5] if len(args) > 5 else False,
     }
     return acc_ops_converters.acc_ops_max_poolnd(
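
Nested conditional expressions get the same treatment: the whole ternary chain that picks a 2D or 3D default is grouped in parentheses. A standalone sketch with hypothetical positional args mimicking a 2D max-pool call:

# (input, kernel_size, stride) -- if stride were omitted, the chain falls back to a default.
args = ("input_tensor", (3, 3), (2, 2))

stride = (
    args[2]
    if len(args) > 2
    else (None, None) if len(args[1]) == 2 else (None, None, None)
)
print(stride)  # (2, 2)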
