Skip to content

Commit 8f27305

Browse files
authored
Make pass_utils pyre-strict
Differential Revision: D70903186 Pull Request resolved: #9092
1 parent 623cd11 commit 8f27305

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from dataclasses import dataclass
10-
from typing import Callable, List, Optional, Set, Type, Union
10+
from typing import Callable, List, Optional, Set, Union
1111

1212
import torch
1313
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -32,33 +32,33 @@ class CadencePassAttribute:
3232

3333

3434
# A dictionary that maps an ExportPass to its attributes.
35-
ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {}
35+
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
3636

3737

38-
def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute:
38+
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
3939
return ALL_CADENCE_PASSES[p]
4040

4141

4242
# A decorator that registers a pass.
4343
def register_cadence_pass(
4444
pass_attribute: CadencePassAttribute,
45-
) -> Callable[[Type[ExportPass]], Type[ExportPass]]:
46-
def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]:
45+
) -> Callable[[ExportPass], ExportPass]:
46+
def wrapper(cls: ExportPass) -> ExportPass:
4747
ALL_CADENCE_PASSES[cls] = pass_attribute
4848
return cls
4949

5050
return wrapper
5151

5252

53-
def get_all_available_cadence_passes() -> Set[Type[ExportPass]]:
53+
def get_all_available_cadence_passes() -> Set[ExportPass]:
5454
return set(ALL_CADENCE_PASSES.keys())
5555

5656

5757
# Create a new filter to filter out relevant passes from all passes.
5858
def create_cadence_pass_filter(
5959
opt_level: int, debug: bool = False
60-
) -> Callable[[Type[ExportPass]], bool]:
61-
def _filter(p: Type[ExportPass]) -> bool:
60+
) -> Callable[[ExportPass], bool]:
61+
def _filter(p: ExportPass) -> bool:
6262
pass_attribute = get_cadence_pass_attribute(p)
6363
return (
6464
pass_attribute.opt_level is not None

backends/cadence/aot/passes.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Any, cast, List, Optional, Type
9+
from typing import Any, List, Optional
1010

1111
import torch
1212
import torch.fx
@@ -71,7 +71,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
7171
Argument = Any # pyre-ignore
7272

7373

74-
def get_passes_in_default_order() -> List[Type[PassType]]:
74+
def get_passes_in_default_order() -> List[ExportPass]:
7575
passes = [
7676
InitializePipeline,
7777
RemoveRedundantOps.passes,
@@ -95,9 +95,8 @@ def get_cadence_passes(
9595
passes = get_passes_in_default_order()
9696
pass_filter = create_cadence_pass_filter(opt_level)
9797
filtered_passes = [
98+
# pyre-ignore[20]: Expect argument graph_module
9899
filtered_pass()
99-
# pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
100100
for filtered_pass in list(filter(pass_filter, passes))
101101
]
102-
# The type checker can't infer the proper type of the list comprehension.
103-
return cast(List[Optional[PassResult]], filtered_passes)
102+
return filtered_passes

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,6 +1719,7 @@ def call_operator(self, op, args, kwargs, meta):
17191719
)
17201720

17211721

1722+
# pyre-ignore[6]: Incompatible parameter type (doesn't get the inheritance)
17221723
register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)
17231724

17241725

0 commit comments

Comments
 (0)