4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from collections import OrderedDict
8
+ from typing import cast , Mapping
9
+
7
10
import torch
8
- from torch ._export .utils import get_buffer , get_param , is_buffer , is_param
11
+ from executorch .exir .dialects ._ops import ops as exir_ops
12
+ from torch ._export .utils import (
13
+ get_buffer ,
14
+ get_lifted_tensor_constant ,
15
+ get_param ,
16
+ is_buffer ,
17
+ is_lifted_tensor_constant ,
18
+ is_param ,
19
+ )
9
20
from torch ._guards import detect_fake_mode
10
21
from torch .export import ExportedProgram
11
22
from torch .export .exported_program import InputKind , InputSpec , TensorArgument
23
+ from torch .utils import _pytree as pytree
24
+
25
+
26
# Targets that must never be constant-folded by this pass.
# Avoid propagating constants for `exir.ops.edge.aten.full.default`:
# folding aten.full would embed the materialized tensor in the program,
# which can significantly increase compiled model size.
SKIP_TARGETS = {exir_ops.edge.aten.full.default}
12
29
13
30
14
- def is_const (arg , exported_program , const_data_list ) -> bool :
31
+ def is_const (
32
+ arg ,
33
+ exported_program : ExportedProgram ,
34
+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
35
+ allow_propagation_of_scalars : bool ,
36
+ ) -> bool :
15
37
if isinstance (arg , (tuple , list )):
16
- return all (is_const (x , exported_program , const_data_list ) for x in arg )
38
+ return all (
39
+ is_const (
40
+ x , exported_program , const_node_to_tensor , allow_propagation_of_scalars
41
+ )
42
+ for x in arg
43
+ )
17
44
elif isinstance (arg , dict ):
18
- return all (is_const (x , exported_program , const_data_list ) for x in arg .values ())
19
- elif not isinstance (arg , torch .fx .Node ) or arg .op != "placeholder" :
45
+ return all (
46
+ is_const (
47
+ x , exported_program , const_node_to_tensor , allow_propagation_of_scalars
48
+ )
49
+ for x in arg .values ()
50
+ )
51
+ elif isinstance (arg , (int , float , bool , str , torch .dtype )):
52
+ return allow_propagation_of_scalars
53
+ elif not isinstance (arg , torch .fx .Node ):
20
54
return False
21
- elif (
22
- is_param (exported_program , arg )
23
- or is_buffer (exported_program , arg )
24
- or arg .name in const_data_list
25
- ):
55
+ elif arg in const_node_to_tensor :
26
56
return True
27
57
return False
28
58
29
59
30
- def get_data (exported_program , arg ):
60
+ def get_data (
61
+ arg ,
62
+ exported_program : ExportedProgram ,
63
+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
64
+ ):
31
65
if isinstance (arg , (tuple , list )):
32
- return [get_data (exported_program , x ) for x in arg ]
33
- elif is_param ( exported_program , arg ):
34
- return get_param ( exported_program , arg )
35
- elif is_buffer ( exported_program , arg ) :
36
- return get_buffer ( exported_program , arg )
66
+ return [get_data (x , exported_program , const_node_to_tensor ) for x in arg ]
67
+ elif isinstance ( arg , ( int , float , bool , str , torch . dtype ) ):
68
+ return arg
69
+ elif arg in const_node_to_tensor :
70
+ return const_node_to_tensor [ arg ]
37
71
return None
38
72
39
73
40
def get_constant_placeholder_dict(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Returns a dictionary of placeholder node -> constant tensor.

    Covers the three kinds of lifted constants in an ExportedProgram:
    parameters, buffers, and lifted tensor constants. User-input
    placeholders are not included. Insertion order follows graph order.
    """
    # (predicate, getter) pairs for each kind of constant placeholder.
    accessors = (
        (is_param, get_param),
        (is_buffer, get_buffer),
        (is_lifted_tensor_constant, get_lifted_tensor_constant),
    )
    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue
        for predicate, getter in accessors:
            if predicate(exported_program, node):
                const_node_to_tensor[node] = cast(
                    torch.Tensor, getter(exported_program, node)
                )
                break
    return const_node_to_tensor
50
def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    allow_propagation_of_scalars: bool,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node -> constant tensor.

    Seeds the map with the constant placeholders (params/buffers/lifted
    tensor constants), then walks the graph in order: any call_function
    whose args are all constant is executed eagerly and its result recorded,
    so downstream nodes can fold transitively.
    """
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)

    for node in exported_program.graph.nodes:
        # Only call_function nodes can be folded; skip blocklisted targets
        # (e.g. aten.full, whose folded output would bloat the program).
        if node.op != "call_function" or node.target in SKIP_TARGETS:
            continue

        # NOTE(review): only node.args is const-checked, not node.kwargs —
        # a non-const kwarg would resolve to None below; confirm intent.
        if not is_const(
            node.args,
            exported_program,
            const_node_to_tensor,
            allow_propagation_of_scalars=allow_propagation_of_scalars,
        ):
            continue

        # Map every leaf in (args, kwargs) to its concrete value.
        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )

        # Execute `node.target` eagerly; its output becomes a new constant.
        const_node_to_tensor[node] = node.target(*args_data, **kwargs_data)

    return const_node_to_tensor
135
+
136
+
137
+ def get_first_user_input (exported_program : ExportedProgram ) -> torch .fx .Node :
138
+ """Returns the first user input node in the graph."""
59
139
first_user_input = None
60
140
for node in exported_program .graph .nodes :
61
141
if (
@@ -64,11 +144,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64
144
):
65
145
first_user_input = node
66
146
break
147
+ return first_user_input
148
+
149
+
150
def replace_with_constant_node(
    node: torch.fx.Node,
    prop_constant_tensor: torch.Tensor,
    first_user_input: torch.fx.Node,
    fake_mode,
    exported_program: ExportedProgram,
) -> tuple[torch.fx.Node, str]:
    """Replace `node` with a fresh constant-tensor placeholder.

    Registers `prop_constant_tensor` in `exported_program.constants` under a
    unique FQN, inserts a matching placeholder before `first_user_input`,
    rewires all users of `node` to it, and erases `node`.

    Returns (new placeholder node, constant FQN).
    """
    # Unique name: constants is append-only here, so its size is a counter.
    prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
    exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor

    # Constant placeholders must precede user inputs in the graph.
    with exported_program.graph.inserting_before(first_user_input):
        const_placeholder_node = exported_program.graph.placeholder(
            prop_constant_tensor_fqn
        )

    # Carry over the replaced node's meta, then pin a fake-tensor `val`
    # with static shapes and attach the concrete tensor to it.
    const_placeholder_node.meta.update(node.meta)
    const_placeholder_node.meta["val"] = fake_mode.from_tensor(
        prop_constant_tensor, static_shapes=True
    )
    const_placeholder_node.meta["val"].constant = prop_constant_tensor

    # Rewire users and drop the original node from the graph.
    node.replace_all_uses_with(const_placeholder_node)
    exported_program.graph.erase_node(node)

    return const_placeholder_node, prop_constant_tensor_fqn
71
180
181
+
182
+ def get_fake_mode (exported_program : ExportedProgram ):
72
183
fake_mode = detect_fake_mode (
73
184
tuple (
74
185
node .meta ["val" ]
@@ -77,57 +188,101 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77
188
)
78
189
)
79
190
assert fake_mode is not None
191
+ return fake_mode
80
192
193
+
194
+ def erase_constant_node (
195
+ exported_program : ExportedProgram ,
196
+ node : torch .fx .Node ,
197
+ ):
198
+ # Remove from graph.
199
+ exported_program .graph .erase_node (node )
200
+
201
+ # Remove corresponding tensor from param/constants dict.
202
+ signature = exported_program .graph_signature
203
+ if name := signature .inputs_to_parameters .pop (node .name , None ):
204
+ exported_program .state_dict .pop (name , None )
205
+ elif name := signature .inputs_to_lifted_tensor_constants .pop (node .name , None ):
206
+ exported_program .constants .pop (name , None )
207
+
208
+
209
def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Creates constant nodes for all entries in `const_node_to_tensor` and
    returns a node.name -> InputSpec dict for the created placeholders.

    Entries whose users are themselves all constant are dead (their users
    were folded too) and are erased instead of materialized.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Reverse insertion order so consumers are processed before producers;
    # by the time a producer is visited, its folded users are in the map.
    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
        if all(user in const_node_to_tensor for user in node.users):
            # Every user was folded as well — no live consumer remains.
            erase_constant_node(exported_program, node)
            continue

        const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
            node, prop_constant_tensor, first_user_input, fake_mode, exported_program
        )

        # Input spec for the newly lifted constant tensor.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=prop_constant_tensor_fqn,
            persistent=True,
        )
    return name_to_spec_dict
240
+
241
+
242
def constant_prop_pass(
    exported_program: ExportedProgram, allow_propagation_of_scalars: bool = False
) -> ExportedProgram:
    """
    This pass is for constant propagation for Exported Program with lifted
    parameters, as the parameters will not be shown up as `get_attr` but as
    `placeholder` to the graph.

    Mutates `exported_program` in place and also returns it. Raises
    RuntimeError if the graph contains `torch.ops.higher_order.cond`.
    """
    # Nothing to fold in a graph without placeholders.
    has_placeholder = any(
        node.op == "placeholder" for node in exported_program.graph.nodes
    )
    if not has_placeholder:
        return exported_program

    # Control flow would require folding through subgraphs; unsupported.
    if any(
        node.target == torch.ops.higher_order.cond
        for node in exported_program.graph.nodes
    ):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, allow_propagation_of_scalars=allow_propagation_of_scalars
    )

    # Existing input specs, keyed by placeholder name...
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }
    # ...overlaid with specs for the newly created constant placeholders.
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Rebuild input specs in current placeholder (graph) order.
    exported_program.graph_signature.input_specs = [
        name_to_spec_dict[node.name]
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]

    # Cleanup: drop folded-away nodes and refresh the generated code.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
0 commit comments