Skip to content

Commit 810adc1

Browse files
committed
chore(//py/torch_tensorrt): subsitutions mypy compliant
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 1b10444 commit 810adc1

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch._decomp import register_decomposition, core_aten_decompositions, OpOverload
44

55

6-
DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {**core_aten_decompositions()}
6+
DECOMPOSITIONS: Dict[OpOverload, Callable[..., Any]] = {**core_aten_decompositions()}
77

88
aten = torch.ops.aten
99

@@ -60,5 +60,5 @@ def addmm_replacement(
6060
)
6161

6262

63-
def get_decompositions() -> Dict[OpOverload, Callable[[Any], Any]]:
63+
def get_decompositions() -> Dict[OpOverload, Callable[..., Any]]:
6464
return DECOMPOSITIONS

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
logger = logging.getLogger(__name__)
88

9+
SubgraphInsertionFnType = Callable[
10+
[torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
11+
]
12+
913

1014
@dataclass(frozen=True)
1115
class Substitution:
@@ -18,22 +22,20 @@ class Substitution:
1822
# and returning a replacement node, with type 'call_function', or raising an Error if
1923
# incompatibility is detected
2024
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
21-
subgraph_insertion_fn: Callable[
22-
[torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
23-
]
25+
subgraph_insertion_fn: SubgraphInsertionFnType
2426

2527

2628
# Dictionary mapping module to Substitution instance
2729
SUBSTITUTION_REGISTRY: Dict[
28-
Union[Type[torch.nn.Module], Callable], Substitution
30+
Union[Type[torch.nn.Module], Callable[..., Any]], Substitution
2931
] = dict()
3032

3133

3234
def register_substitution(
33-
module_or_function_to_replace: Union[Type[torch.nn.Module], Callable],
35+
module_or_function_to_replace: Union[Type[torch.nn.Module], Callable[..., Any]],
3436
new_operator: torch._ops.OpOverload,
3537
enabled: bool = True,
36-
) -> Callable[[Any], Any]:
38+
) -> Callable[[SubgraphInsertionFnType], SubgraphInsertionFnType]:
3739
"""Decorator to register subgraph insertion functions
3840
3941
Args:
@@ -44,22 +46,22 @@ def register_substitution(
4446
torch.fx.GraphModule
4547
"""
4648

47-
def enable_substitution(subgraph_insertion_fn):
49+
def enable_substitution(subgraph_insertion_fn: SubgraphInsertionFnType) -> SubgraphInsertionFnType:
4850
"""Function for use if substitution is enabled"""
4951
replacement = Substitution(
5052
new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
5153
)
5254
SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
5355
return subgraph_insertion_fn
5456

55-
def disable_substitution(subgraph_insertion_fn):
57+
def disable_substitution(subgraph_insertion_fn: SubgraphInsertionFnType) -> SubgraphInsertionFnType:
5658
"""Function for use if substitution is disabled"""
5759
return subgraph_insertion_fn
5860

5961
return enable_substitution if enabled else disable_substitution
6062

6163

62-
def pre_aot_substitutions(gm: torch.fx.GraphModule):
64+
def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
6365
"""Perform graph substitutions prior to AOT tracing
6466
6567
Args:
@@ -92,6 +94,7 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule):
9294
# If submodule/function is a member of the substitution registry, replace it
9395
if exists_in_registry:
9496
try:
97+
assert to_replace is not None
9598
replacement = SUBSTITUTION_REGISTRY[to_replace]
9699
op, insertion_fn = (
97100
replacement.new_operator,

0 commit comments

Comments
 (0)