Skip to content

Commit 1b10444

Browse files
committed
chore(//py/torch_tensorrt/dynamo/backend): Backend is mypy conforming
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 790c78f commit 1b10444

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Sequence
2+
from typing import Sequence, Any, Callable
33
import torch
44
from functools import partial
55
import torch._dynamo as td
@@ -24,19 +24,23 @@
2424
logger = logging.getLogger(__name__)
2525

2626

27-
@td.register_backend(name="torch_tensorrt")
27+
@td.register_backend(name="torch_tensorrt") # type: ignore[misc]
2828
def torch_tensorrt_backend(
29-
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
30-
):
29+
gm: torch.fx.GraphModule,
30+
sample_inputs: Sequence[torch.Tensor],
31+
**kwargs: Any
32+
) -> torch.nn.Module:
3133
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3234

33-
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
34-
35+
compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
36+
return compiled_mod
3537

36-
@td.register_backend(name="aot_torch_tensorrt_aten")
38+
@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
3739
def aot_torch_tensorrt_aten_backend(
38-
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
39-
):
40+
gm: torch.fx.GraphModule,
41+
sample_inputs: Sequence[torch.Tensor],
42+
**kwargs: Any
43+
) -> torch.nn.Module:
4044
settings = parse_dynamo_kwargs(kwargs)
4145

4246
custom_backend = partial(
@@ -51,7 +55,7 @@ def aot_torch_tensorrt_aten_backend(
5155
return aot_module_simplified(
5256
gm,
5357
sample_inputs,
54-
fw_compiler=make_boxed_compiler(custom_backend),
58+
fw_compiler=make_boxed_compiler(custom_backend), # type: ignore[no-untyped-call]
5559
decompositions=get_decompositions(),
5660
)
5761

@@ -60,7 +64,7 @@ def _pretraced_backend(
6064
gm: torch.fx.GraphModule,
6165
sample_inputs: Sequence[torch.Tensor],
6266
settings: CompilationSettings = CompilationSettings(),
63-
):
67+
) -> torch.fx.GraphModule | Callable[..., Any]:
6468
"""Helper function to manage translation of traced FX module to TRT engines
6569
6670
Args:
Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1+
from typing import Callable, Dict, Any
12
import torch
2-
from torch._decomp import register_decomposition, core_aten_decompositions
3+
from torch._decomp import register_decomposition, core_aten_decompositions, OpOverload
34

45

5-
DECOMPOSITIONS = {**core_aten_decompositions()}
6+
DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {**core_aten_decompositions()}
67

78
aten = torch.ops.aten
89

9-
10-
def replace_inplace_op(aten_op, outplace_op):
10+
def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
1111
"""Replace inplace operation with functional equivalent
1212
Adapted from:
1313
https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
1414
"""
1515

16-
@register_decomposition(aten_op, registry=DECOMPOSITIONS)
17-
def inplace_op(*args, **kwargs):
16+
@register_decomposition(aten_op, registry=DECOMPOSITIONS) # type: ignore[misc]
17+
def inplace_op(*args: Any, **kwargs: Any) -> Any:
1818
out = outplace_op(*args, **kwargs)
1919
return args[0].copy_(out)
2020

@@ -36,29 +36,29 @@ def inplace_op(*args, **kwargs):
3636
replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
3737

3838

39-
@register_decomposition(aten.std, registry=DECOMPOSITIONS)
40-
def std_replacement(*args, **kwargs) -> torch.Tensor:
39+
@register_decomposition(aten.std, registry=DECOMPOSITIONS) # type: ignore[misc]
40+
def std_replacement(*args: Any, **kwargs: Any) -> torch.Tensor:
4141
return torch.sqrt(torch.var(*args, **kwargs))
4242

4343

44-
@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)
45-
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
44+
@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS) # type: ignore[misc]
45+
def rsqrt_replacement(*args: Any, **kwargs: Any) -> torch.Tensor:
4646
return torch.reciprocal(torch.sqrt(*args, **kwargs))
4747

4848

49-
@register_decomposition(aten.alias, registry=DECOMPOSITIONS)
49+
@register_decomposition(aten.alias, registry=DECOMPOSITIONS) # type: ignore[misc]
5050
def alias_replacement(x: torch.Tensor) -> torch.Tensor:
5151
return x
5252

5353

54-
@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
54+
@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS) # type: ignore[misc]
5555
def addmm_replacement(
56-
input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
56+
input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta: int = 1, alpha: int = 1
5757
) -> torch.Tensor:
5858
return torch.add(
5959
torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
6060
)
6161

6262

63-
def get_decompositions():
63+
def get_decompositions() -> Dict[OpOverload, Callable[[Any], Any]]:
6464
return DECOMPOSITIONS

0 commit comments

Comments
 (0)