7
7
# pyre-strict
8
8
9
9
from dataclasses import dataclass
10
- from typing import Callable , List , Optional , Set , Type , Union
10
+ from typing import Callable , List , Optional , Set , Union
11
11
12
12
import torch
13
13
from executorch .backends .cadence .aot .utils import get_edge_overload_packet
@@ -32,33 +32,33 @@ class CadencePassAttribute:
32
32
33
33
34
34
# A dictionary that maps an ExportPass to its attributes.
35
- ALL_CADENCE_PASSES : dict [Type [ ExportPass ] , CadencePassAttribute ] = {}
35
+ ALL_CADENCE_PASSES : dict [ExportPass , CadencePassAttribute ] = {}
36
36
37
37
38
- def get_cadence_pass_attribute (p : Type [ ExportPass ] ) -> CadencePassAttribute :
38
+ def get_cadence_pass_attribute (p : ExportPass ) -> CadencePassAttribute :
39
39
return ALL_CADENCE_PASSES [p ]
40
40
41
41
42
42
# A decorator that registers a pass.
43
43
def register_cadence_pass (
44
44
pass_attribute : CadencePassAttribute ,
45
- ) -> Callable [[Type [ ExportPass ]], Type [ ExportPass ] ]:
46
- def wrapper (cls : Type [ ExportPass ] ) -> Type [ ExportPass ] :
45
+ ) -> Callable [[ExportPass ], ExportPass ]:
46
+ def wrapper (cls : ExportPass ) -> ExportPass :
47
47
ALL_CADENCE_PASSES [cls ] = pass_attribute
48
48
return cls
49
49
50
50
return wrapper
51
51
52
52
53
- def get_all_available_cadence_passes () -> Set [Type [ ExportPass ] ]:
53
+ def get_all_available_cadence_passes () -> Set [ExportPass ]:
54
54
return set (ALL_CADENCE_PASSES .keys ())
55
55
56
56
57
57
# Create a new filter to filter out relevant passes from all passes.
58
58
def create_cadence_pass_filter (
59
59
opt_level : int , debug : bool = False
60
- ) -> Callable [[Type [ ExportPass ] ], bool ]:
61
- def _filter (p : Type [ ExportPass ] ) -> bool :
60
+ ) -> Callable [[ExportPass ], bool ]:
61
+ def _filter (p : ExportPass ) -> bool :
62
62
pass_attribute = get_cadence_pass_attribute (p )
63
63
return (
64
64
pass_attribute .opt_level is not None
0 commit comments