4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from typing import Callable , List , Optional
8
+
7
9
import torch
8
10
from torch ._export .utils import get_buffer , get_param , is_buffer , is_param
9
11
from torch ._guards import detect_fake_mode
10
12
from torch .export import ExportedProgram
11
13
from torch .export .exported_program import InputKind , InputSpec , TensorArgument
12
14
13
15
14
- def is_const (arg , exported_program , const_data_list ) -> bool :
16
# Value types that the constant-propagation helpers treat as already-constant:
# plain Python scalars/strings plus the torch objects that hold or describe
# tensor data.  Referenced by is_const() and get_data() in this module.
_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
26
+
27
+
28
+ def is_const (
29
+ arg : object , exported_program : ExportedProgram , const_data_list : List [str ]
30
+ ) -> bool :
15
31
if isinstance (arg , (tuple , list )):
16
32
return all (is_const (x , exported_program , const_data_list ) for x in arg )
17
33
elif isinstance (arg , dict ):
18
34
return all (is_const (x , exported_program , const_data_list ) for x in arg .values ())
35
+ elif isinstance (arg , _PRIMITIVE_TYPES ):
36
+ return True
19
37
elif not isinstance (arg , torch .fx .Node ) or arg .op != "placeholder" :
20
38
return False
21
39
elif (
@@ -27,17 +45,22 @@ def is_const(arg, exported_program, const_data_list) -> bool:
27
45
return False
28
46
29
47
30
- def get_data (exported_program , arg ):
48
+ def get_data (exported_program : ExportedProgram , arg ):
31
49
if isinstance (arg , (tuple , list )):
32
50
return [get_data (exported_program , x ) for x in arg ]
51
+ elif isinstance (arg , _PRIMITIVE_TYPES ):
52
+ return arg
33
53
elif is_param (exported_program , arg ):
34
54
return get_param (exported_program , arg )
35
55
elif is_buffer (exported_program , arg ):
36
56
return get_buffer (exported_program , arg )
37
57
return None
38
58
39
59
40
- def constant_prop_pass (exported_program : ExportedProgram ) -> ExportedProgram :
60
+ def constant_prop_pass (
61
+ exported_program : ExportedProgram ,
62
+ skip_folding_node_fn : Optional [Callable [[torch .fx .Node ], bool ]] = None ,
63
+ ) -> ExportedProgram :
41
64
"""
42
65
This pass is for constant propagation for Exported Program with lifted parameters,
43
66
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
@@ -56,12 +79,14 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
56
79
if len (has_cond ) > 0 :
57
80
raise RuntimeError ("constant_prop_pass for control flow is not supported yet." )
58
81
82
+ first_user_input_idx = - 1
59
83
first_user_input = None
60
- for node in exported_program .graph .nodes :
84
+ for i , node in enumerate ( exported_program .graph .nodes ) :
61
85
if (
62
86
node .op == "placeholder"
63
87
and node .name in exported_program .graph_signature .user_inputs
64
88
):
89
+ first_user_input_idx = i
65
90
first_user_input = node
66
91
break
67
92
@@ -79,6 +104,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
79
104
assert fake_mode is not None
80
105
81
106
for node in exported_program .graph .nodes :
107
+ if skip_folding_node_fn is not None and skip_folding_node_fn (node ):
108
+ # Do not process this node if we were told to skip it.
109
+ continue
82
110
if node .op == "call_function" :
83
111
constant_data_name_list = [
84
112
input_spec .target for input_spec in prop_constant_data
@@ -115,9 +143,11 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
115
143
exported_program .state_dict [prop_constant_tensor_fqn ] = (
116
144
prop_constant_tensor
117
145
)
118
- exported_program .graph_signature .input_specs .append (
119
- prop_constant_node_input_spec
146
+ # Insert new buffers before the first user input.
147
+ exported_program .graph_signature .input_specs .insert (
148
+ first_user_input_idx , prop_constant_node_input_spec
120
149
)
150
+ first_user_input_idx += 1
121
151
122
152
# Remove the propogated buffer from the state dict
123
153
for node in exported_program .graph .nodes :
@@ -128,6 +158,16 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
128
158
):
129
159
exported_program .state_dict .pop (node .name , None )
130
160
exported_program .graph .erase_node (node )
161
+ # Delete the input spec for this deleted buffer.
162
+ to_erase_idx = []
163
+ for i , spec in enumerate (exported_program .graph_signature .input_specs ):
164
+ if spec .arg .name == node .name :
165
+ to_erase_idx .append (i )
166
+ assert (
167
+ len (to_erase_idx ) == 1
168
+ ), f"Should only delete one spec per node, but deleting multiple: { to_erase_idx } { exported_program .graph_signature .input_specs } "
169
+ for i in reversed (to_erase_idx ):
170
+ exported_program .graph_signature .input_specs .pop (i )
131
171
132
172
exported_program .graph_module .recompile ()
133
173
return exported_program
0 commit comments