Skip to content

Commit 4eb9361

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Make pass_utils pyre-strict
Summary: As titled. Requires the removal of some unneeded `Type[]` markers in the OSS aot file too. Differential Revision: D70903186
1 parent 2094659 commit 4eb9361

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from dataclasses import dataclass
10-
from typing import Callable, List, Optional, Set, Type, Union
10+
from typing import Callable, List, Optional, Set, Union
1111

1212
import torch
1313
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -32,33 +32,33 @@ class CadencePassAttribute:
3232

3333

3434
# A dictionary that maps an ExportPass to its attributes.
35-
ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {}
35+
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
3636

3737

38-
def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute:
38+
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
3939
return ALL_CADENCE_PASSES[p]
4040

4141

4242
# A decorator that registers a pass.
4343
def register_cadence_pass(
4444
pass_attribute: CadencePassAttribute,
45-
) -> Callable[[Type[ExportPass]], Type[ExportPass]]:
46-
def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]:
45+
) -> Callable[[ExportPass], ExportPass]:
46+
def wrapper(cls: ExportPass) -> ExportPass:
4747
ALL_CADENCE_PASSES[cls] = pass_attribute
4848
return cls
4949

5050
return wrapper
5151

5252

53-
def get_all_available_cadence_passes() -> Set[Type[ExportPass]]:
53+
def get_all_available_cadence_passes() -> Set[ExportPass]:
5454
return set(ALL_CADENCE_PASSES.keys())
5555

5656

5757
# Create a new filter to filter out relevant passes from all passes.
5858
def create_cadence_pass_filter(
5959
opt_level: int, debug: bool = False
60-
) -> Callable[[Type[ExportPass]], bool]:
61-
def _filter(p: Type[ExportPass]) -> bool:
60+
) -> Callable[[ExportPass], bool]:
61+
def _filter(p: ExportPass) -> bool:
6262
pass_attribute = get_cadence_pass_attribute(p)
6363
return (
6464
pass_attribute.opt_level is not None

0 commit comments

Comments
 (0)