
Commit 0e9def4

chore(//py/torch_tensorrt/dynamo/lowering): mypy compilance
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 810adc1 commit 0e9def4

6 files changed: +58 -45 lines changed

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Set
 import torch
 from torch_tensorrt.dynamo._defaults import (
     PRECISION,
@@ -20,7 +20,7 @@ class CompilationSettings:
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
-    torch_executed_ops: Sequence[str] = field(default_factory=set)
+    torch_executed_ops: Set[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS
     version_compatible: bool = VERSION_COMPATIBLE
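For context on this change: a default produced by field(default_factory=set) is a set, which is not a Sequence, so mypy flags the old Sequence[str] annotation; Set[str] matches the factory. A minimal standalone sketch of the same pattern (ExampleSettings and the op name are illustrative, not project code):

from dataclasses import dataclass, field
from typing import Set


@dataclass
class ExampleSettings:
    # A mutable default must go through default_factory; annotating the field
    # as Set[str] agrees with the set() factory, while Sequence[str] does not.
    torch_executed_ops: Set[str] = field(default_factory=set)


settings = ExampleSettings()
settings.torch_executed_ops.add("aten::relu")  # hypothetical op name
print(settings.torch_executed_ops)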

py/torch_tensorrt/dynamo/lowering/_fusers.py

Lines changed: 12 additions & 4 deletions
@@ -2,16 +2,24 @@
 from torch_tensorrt.fx.tracer.acc_tracer import acc_ops


-def check_permute(node: torch.fx.Node):
+def check_permute(node: torch.fx.Node) -> bool:
     ranks = len(node.meta["tensor_meta"].shape)
     permutation = list(i % ranks for i in node.kwargs["permutation"])  # type: ignore[union-attr]
     allowed_permutation = list(i for i in range(ranks))
     allowed_permutation[-1] = ranks - 2
     allowed_permutation[-2] = ranks - 1
     return permutation == allowed_permutation

+def trt_transposed_matmul(
+    lhs: torch.Tensor, rhs: torch.Tensor, lhs_transposed: bool, rhs_transposed: bool
+) -> torch.Tensor:
+    if lhs_transposed:
+        lhs = lhs.transpose(-1, -2)
+    if rhs_transposed:
+        rhs = rhs.transpose(-1, -2)
+    return torch.matmul(lhs, rhs)

-def fuse_permute_matmul(gm: torch.fx.GraphModule):
+def fuse_permute_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     Fuse pattern like permute + matmul if permute is transposing the last two dimension.
     """
@@ -45,11 +53,11 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule):

 def trt_transposed_linear(
     input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
-):
+) -> torch.Tensor:
     return torch.matmul(input.transpose(-1, -2), weight.t()) + bias


-def fuse_permute_linear(gm: torch.fx.GraphModule):
+def fuse_permute_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     Fuse pattern like permute + linear if permute is transposing the last two dimension.
     """

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 17 additions & 19 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Sequence, Set
+from typing import Dict, List, Optional, Sequence, Set, Mapping, Any

 import torch

@@ -8,14 +8,14 @@
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
-from torch.fx.passes.operator_support import OperatorSupport
+from torch.fx.passes.operator_support import OperatorSupport, SupportDict

 from torch_tensorrt.fx.converter_registry import CONVERTERS


 logger = logging.getLogger(__name__)

-DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+DEFAULT_SINGLE_NODE_PARTITIONS: List[str] = list(
     _get_qualified_name(to_replace.new_operator)
     for to_replace in SUBSTITUTION_REGISTRY.values()
 )
@@ -41,10 +41,8 @@ def __init__(
         operator_support: OperatorSupport,
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[
-            Sequence[str]
-        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
-        min_block_size=MIN_BLOCK_SIZE,
+        allowed_single_node_partition_ops: Optional[Sequence[str]] = DEFAULT_SINGLE_NODE_PARTITIONS,
+        min_block_size: int = MIN_BLOCK_SIZE,
     ) -> None:
         super().__init__(
             graph_module,
@@ -74,14 +72,14 @@ def propose_partitions(self) -> List[Partition]:
                 # Partitions are exempted from min_block_size if they contain an allowed single-node op
                 if (
                     node.op == "call_function"
-                    and _get_qualified_name(node.target)
+                    and _get_qualified_name(node.target)  # type: ignore[arg-type]
                     in self.allowed_single_node_partition_ops
                 ):
                     exempted_partition = True
                     break
                 elif (
                     node.op == "call_function"
-                    and _get_qualified_name(node.target) not in non_compute_ops
+                    and _get_qualified_name(node.target) not in non_compute_ops  # type: ignore[arg-type]
                 ):
                     compute_node_count += 1

@@ -106,16 +104,16 @@ def partition_and_fuse(self) -> GraphModule:
 class TorchTensorRTOperatorSupport(OperatorSupport):
     """Class to determine whether operators within a module are supported"""

-    def __init__(self, support_dict=None, torch_executed_ops=set()):
+    def __init__(self, support_dict: Optional[SupportDict] = None, torch_executed_ops: Set[str] = set()):
         super().__init__(support_dict)

         # Initialize sets of supported/unsupported operators
-        self.supported_operators = set()
-        self.unsupported_operators = set()
-        self.torch_executed_ops = torch_executed_ops
+        self.supported_operators: Set[str] = set()
+        self.unsupported_operators: Set[str] = set()
+        self.torch_executed_ops: Set[str] = torch_executed_ops

     def is_node_supported(
-        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
+        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
         node_name = (
             _get_qualified_name(node.target)
@@ -138,7 +136,7 @@ def is_node_supported(

         return False

-    def print_support_overview(self, num_trt_blocks: Optional[int] = None):
+    def print_support_overview(self, num_trt_blocks: Optional[int] = None) -> None:
         if num_trt_blocks is not None:
             logger.debug(
                 f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
@@ -164,7 +162,7 @@ def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
+    torch_executed_ops: Set[str] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -195,7 +193,7 @@ def get_submod_inputs(
     mod: torch.fx.GraphModule,
     submod: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-) -> Sequence[torch.Tensor]:
+) -> Optional[Sequence[torch.Tensor]]:
     """Helper function to get inputs to a Torch submodule

     Args:
@@ -205,9 +203,9 @@ def get_submod_inputs(
     Returns:
         Sequence of Tensors representing inputs to child module
     """
-    acc_inputs = None
+    acc_inputs: Optional[Sequence[torch.Tensor]] = None

-    def get_input(self, inputs):
+    def get_input(_: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]) -> None:
         nonlocal acc_inputs
         acc_inputs = inputs

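The new get_input signature follows how PyTorch forward pre-hooks are called, with (module, inputs). A small standalone sketch of the same capture pattern, using a toy module; the hook registration shown here is an assumption about how the helper is wired up, not a copy of it:

from typing import Optional, Sequence

import torch


def capture_inputs(
    mod: torch.nn.Module, inputs: Sequence[torch.Tensor]
) -> Optional[Sequence[torch.Tensor]]:
    # Register a forward pre-hook, run the module once, and record the
    # positional inputs the hook observes via a nonlocal variable.
    captured: Optional[Sequence[torch.Tensor]] = None

    def hook(_: torch.nn.Module, hook_inputs: Sequence[torch.Tensor]) -> None:
        nonlocal captured
        captured = hook_inputs

    handle = mod.register_forward_pre_hook(hook)
    mod(*inputs)
    handle.remove()
    return captured


submod_inputs = capture_inputs(torch.nn.Linear(3, 2), (torch.randn(1, 3),))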

py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py

Lines changed: 11 additions & 10 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any, List, Optional, Sequence
 import torch
 from torch._custom_op.impl import custom_op
 from torch.fx.node import Argument, Target
@@ -14,18 +14,18 @@
     qualname="tensorrt::einsum",
     manual_schema="(str equation, Tensor[] tensors) -> Tensor",
 )
-def einsum(equation, tensors):
+def einsum(equation: str, tensors: List[torch.Tensor]) -> torch.Tensor:  # type: ignore[empty-body]
     # Defines operator schema, name, namespace, and function header
     ...


-@einsum.impl("cpu")
-@einsum.impl("cuda")
-@einsum.impl_abstract()
+@einsum.impl("cpu")  # type: ignore[misc]
+@einsum.impl("cuda")  # type: ignore[misc]
+@einsum.impl_abstract()  # type: ignore[misc]
 def einsum_generic(
-    *args,
-    **kwargs,
-):
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
     # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
     return torch.einsum(
         *args,
@@ -42,6 +42,7 @@ def aten_ops_einsum(
     name: str,
 ) -> TRTTensor:
     # Defines converter replacing the default operator for this function
+    assert isinstance(args[1], Sequence)
     for input_trt in args[1]:
         if not isinstance(input_trt, TRTTensor):
             raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")
@@ -56,7 +57,7 @@
 def einsum_insertion_fn(
     gm: torch.fx.GraphModule,
     node: torch.fx.Node,
-    _unused: None = None,
+    _: Optional[torch.nn.Module] = None,
 ) -> torch.fx.Node:
     equation = node.args[0]

@@ -71,7 +72,7 @@ def einsum_insertion_fn(
     ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"

     # Ensure the input is formatted as an equation and
-    new_node = gm.graph.call_function(
+    new_node: torch.fx.Node = gm.graph.call_function(
         torch.ops.tensorrt.einsum,
         args=(equation, inputs),
         kwargs=node.kwargs,
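The added assert is a common way to narrow a loosely typed argument for mypy before iterating over it. A minimal standalone sketch of the same narrowing idea (converter_like and its signature are illustrative, not the converter's real interface):

from typing import Any, Sequence, Tuple

import torch


def converter_like(args: Tuple[Any, ...]) -> torch.Tensor:
    # args[1] arrives typed as Any; asserting it is a Sequence lets mypy
    # accept the iteration below without a cast, and fails loudly otherwise.
    assert isinstance(args[1], Sequence)
    for tensor in args[1]:
        assert isinstance(tensor, torch.Tensor)
    return torch.einsum(args[0], *args[1])


out = converter_like(("ij,jk->ik", [torch.randn(2, 3), torch.randn(3, 4)]))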

py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py

Lines changed: 11 additions & 10 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any, Optional
 import torch
 from torch._custom_op.impl import custom_op
 from torch.fx.node import Argument, Target
@@ -26,7 +26,7 @@
     qualname="tensorrt::maxpool1d",
     manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
 )
-def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
+def maxpool1d(x: torch.Tensor, kernel_size: Tuple[int], stride: Tuple[int], padding: Tuple[int], dilation: Tuple[int], ceil_mode: bool) -> torch.Tensor:  # type: ignore[empty-body]
     # Defines operator schema, name, namespace, and function header
     ...

@@ -38,13 +38,13 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
 # is desirable. If the operator to replace is a custom module you've written, then add its Torch
 # implementation here. Note that the function header to the generic function can have specific arguments
 # as in the above placeholder
-@maxpool1d.impl("cpu")
-@maxpool1d.impl("cuda")
-@maxpool1d.impl_abstract()
+@maxpool1d.impl("cpu")  # type: ignore[misc]
+@maxpool1d.impl("cuda")  # type: ignore[misc]
+@maxpool1d.impl_abstract()  # type: ignore[misc]
 def maxpool1d_generic(
-    *args,
-    **kwargs,
-):
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
     # Defines an implementation for AOT Autograd to use for shape analysis/propagation
     return torch.nn.functional.max_pool1d(
         *args,
@@ -75,10 +75,11 @@ def maxpool1d_generic(
 def maxpool1d_insertion_fn(
     gm: torch.fx.GraphModule,
     node: torch.fx.Node,
-    submodule: torch.nn.Module,
+    submodule: Optional[torch.nn.Module],
 ) -> torch.fx.Node:
     # Defines insertion function for new node
-    new_node = gm.graph.call_function(
+    assert submodule is not None
+    new_node: torch.fx.Node = gm.graph.call_function(
         torch.ops.tensorrt.maxpool1d,
         args=node.args,
         kwargs={
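For reference on the schema being registered: torch.nn.functional.max_pool1d takes the same kernel_size/stride/padding/dilation/ceil_mode arguments, which is what the generic implementation dispatches to. A quick standalone sketch:

import torch
import torch.nn.functional as F

# (batch, channels, length) input; pooling parameters line up with the
# "(Tensor x, int[1] kernel_size, ...)" schema registered above.
x = torch.randn(1, 4, 16)
out = F.max_pool1d(x, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
print(out.shape)  # torch.Size([1, 4, 8])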

pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,11 @@ exclude = [
     "docsrc/",
     "tests/",
 ]
+python_version = "3.10"
+
+[[tool.mypy.overrides]]
+module = "torch_tensorrt.dynamo.lowering._decompositions"
+disallow_untyped_calls = false

 [tool.setuptools]
 package-dir = {"" = "py"}
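The per-module override relaxes one check: when disallow_untyped_calls is enabled, mypy rejects calls from annotated code into unannotated functions, and the override permits them for _decompositions only. A minimal illustration of the kind of call affected (illustrative names, not the module's actual contents):

def untyped_helper(x):  # no annotations, so mypy treats calls to it as untyped
    return x * 2


def typed_caller(x: int) -> int:
    # With disallow_untyped_calls enabled this call is reported as an error
    # ("Call to untyped function ... in typed context"); the pyproject.toml
    # override silences it for the _decompositions module.
    return untyped_helper(x)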
