6
6
7
7
logger = logging .getLogger (__name__ )
8
8
9
+ SubgraphInsertionFnType = Callable [
10
+ [torch .fx .GraphModule , torch .fx .Node , Optional [torch .nn .Module ]], torch .fx .Node
11
+ ]
12
+
9
13
10
14
@dataclass (frozen = True )
11
15
class Substitution :
@@ -18,22 +22,20 @@ class Substitution:
18
22
# and returning a replacement node, with type 'call_function', or raising an Error if
19
23
# incompatibility is detected
20
24
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
21
- subgraph_insertion_fn : Callable [
22
- [torch .fx .GraphModule , torch .fx .Node , Optional [torch .nn .Module ]], torch .fx .Node
23
- ]
25
+ subgraph_insertion_fn : SubgraphInsertionFnType
24
26
25
27
26
28
# Dictionary mapping module to Substitution instance
27
29
SUBSTITUTION_REGISTRY : Dict [
28
- Union [Type [torch .nn .Module ], Callable ], Substitution
30
+ Union [Type [torch .nn .Module ], Callable [..., Any ] ], Substitution
29
31
] = dict ()
30
32
31
33
32
34
def register_substitution (
33
- module_or_function_to_replace : Union [Type [torch .nn .Module ], Callable ],
35
+ module_or_function_to_replace : Union [Type [torch .nn .Module ], Callable [..., Any ] ],
34
36
new_operator : torch ._ops .OpOverload ,
35
37
enabled : bool = True ,
36
- ) -> Callable [[Any ], Any ]:
38
+ ) -> Callable [[SubgraphInsertionFnType ], SubgraphInsertionFnType ]:
37
39
"""Decorator to register subgraph insertion functions
38
40
39
41
Args:
@@ -44,22 +46,22 @@ def register_substitution(
44
46
torch.fx.GraphModule
45
47
"""
46
48
47
- def enable_substitution (subgraph_insertion_fn ) :
49
+ def enable_substitution (subgraph_insertion_fn : SubgraphInsertionFnType ) -> SubgraphInsertionFnType :
48
50
"""Function for use if substitution is enabled"""
49
51
replacement = Substitution (
50
52
new_operator = new_operator , subgraph_insertion_fn = subgraph_insertion_fn
51
53
)
52
54
SUBSTITUTION_REGISTRY [module_or_function_to_replace ] = replacement
53
55
return subgraph_insertion_fn
54
56
55
- def disable_substitution (subgraph_insertion_fn ) :
57
+ def disable_substitution (subgraph_insertion_fn : SubgraphInsertionFnType ) -> SubgraphInsertionFnType :
56
58
"""Function for use if substitution is disabled"""
57
59
return subgraph_insertion_fn
58
60
59
61
return enable_substitution if enabled else disable_substitution
60
62
61
63
62
- def pre_aot_substitutions (gm : torch .fx .GraphModule ):
64
+ def pre_aot_substitutions (gm : torch .fx .GraphModule ) -> torch . fx . GraphModule :
63
65
"""Perform graph substitutions prior to AOT tracing
64
66
65
67
Args:
@@ -92,6 +94,7 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule):
92
94
# If submodule/function is a member of the substitution registry, replace it
93
95
if exists_in_registry :
94
96
try :
97
+ assert to_replace is not None
95
98
replacement = SUBSTITUTION_REGISTRY [to_replace ]
96
99
op , insertion_fn = (
97
100
replacement .new_operator ,
0 commit comments