4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from collections import OrderedDict
8
+ from typing import cast , Mapping , Optional
9
+
7
10
import torch
8
- from torch ._export .utils import get_buffer , get_param , is_buffer , is_param
11
+ from executorch .exir .dialects ._ops import ops as exir_ops
12
+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
13
+ from torch ._export .utils import (
14
+ get_buffer ,
15
+ get_lifted_tensor_constant ,
16
+ get_param ,
17
+ is_buffer ,
18
+ is_lifted_tensor_constant ,
19
+ is_param ,
20
+ )
9
21
from torch ._guards import detect_fake_mode
10
22
from torch .export import ExportedProgram
11
23
from torch .export .exported_program import InputKind , InputSpec , TensorArgument
24
+ from torch .utils import _pytree as pytree
25
+
26
+
27
+ # Avoid propagating constants for `exir.ops.edge.aten.full.default`.
28
+ # Propagating aten.full can significantly increase compiled model size.
29
+ _DEFAULT_SKIP_TARGETS = {exir_ops .edge .aten .full .default }
12
30
31
+ _PRIMITIVE_TYPES = (
32
+ float ,
33
+ int ,
34
+ bool ,
35
+ str ,
36
+ torch .Tensor ,
37
+ torch .device ,
38
+ torch .dtype ,
39
+ torch .layout ,
40
+ )
13
41
14
- def is_const (arg , exported_program , const_data_list ) -> bool :
42
+
43
+ def is_const (
44
+ arg ,
45
+ exported_program : ExportedProgram ,
46
+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
47
+ ) -> bool :
15
48
if isinstance (arg , (tuple , list )):
16
- return all (is_const (x , exported_program , const_data_list ) for x in arg )
49
+ return all (is_const (x , exported_program , const_node_to_tensor ) for x in arg )
17
50
elif isinstance (arg , dict ):
18
- return all (is_const (x , exported_program , const_data_list ) for x in arg .values ())
19
- elif not isinstance (arg , torch .fx .Node ) or arg .op != "placeholder" :
51
+ return all (
52
+ is_const (x , exported_program , const_node_to_tensor ) for x in arg .values ()
53
+ )
54
+ elif isinstance (arg , _PRIMITIVE_TYPES ):
55
+ return True
56
+ elif not isinstance (arg , torch .fx .Node ):
20
57
return False
21
- elif (
22
- is_param (exported_program , arg )
23
- or is_buffer (exported_program , arg )
24
- or arg .name in const_data_list
25
- ):
58
+ elif arg in const_node_to_tensor :
26
59
return True
27
60
return False
28
61
29
62
30
- def get_data (exported_program , arg ):
63
+ def get_data (
64
+ arg ,
65
+ exported_program : ExportedProgram ,
66
+ const_node_to_tensor : Mapping [torch .fx .Node , torch .Tensor ],
67
+ ):
31
68
if isinstance (arg , (tuple , list )):
32
- return [get_data (exported_program , x ) for x in arg ]
33
- elif is_param ( exported_program , arg ):
34
- return get_param ( exported_program , arg )
35
- elif is_buffer ( exported_program , arg ) :
36
- return get_buffer ( exported_program , arg )
69
+ return [get_data (x , exported_program , const_node_to_tensor ) for x in arg ]
70
+ elif isinstance ( arg , _PRIMITIVE_TYPES ):
71
+ return arg
72
+ elif arg in const_node_to_tensor :
73
+ return const_node_to_tensor [ arg ]
37
74
return None
38
75
39
76
40
def get_constant_placeholder_dict(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Returns a dictionary of placeholder node -> constant tensor.

    Covers the three kinds of lifted constant inputs — parameters, buffers,
    and lifted tensor constants — in graph (insertion) order.
    """
    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
    placeholders = (
        node for node in exported_program.graph.nodes if node.op == "placeholder"
    )
    for node in placeholders:
        if is_param(exported_program, node):
            tensor = get_param(exported_program, node)
        elif is_buffer(exported_program, node):
            tensor = get_buffer(exported_program, node)
        elif is_lifted_tensor_constant(exported_program, node):
            tensor = get_lifted_tensor_constant(exported_program, node)
        else:
            # Regular user input; not a constant.
            continue
        const_node_to_tensor[node] = cast(torch.Tensor, tensor)
    return const_node_to_tensor
50
101
51
- has_cond = [
52
- node
53
- for node in exported_program .graph .nodes
54
- if node .target == torch .ops .higher_order .cond
55
- ]
56
- if len (has_cond ) > 0 :
57
- raise RuntimeError ("constant_prop_pass for control flow is not supported yet." )
58
102
103
def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node->constant tensors.

    Starts from the constant placeholders (params/buffers/lifted constants)
    and folds every call_function node whose positional args are all
    constant, except targets in the default/custom skip sets.
    """
    # Seed with every constant placeholder in the graph.
    const_node_to_tensor = get_propagated_const_tensor_dict_seed(exported_program)

    # Merge the default skip set with any caller-provided one.
    skip_targets: set[EdgeOpOverload] = set(_DEFAULT_SKIP_TARGETS)
    if custom_skip_targets is not None:
        skip_targets |= custom_skip_targets

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in skip_targets:
            continue
        # NOTE: only node.args is checked for const-ness, matching how the
        # data is gathered below via tree_map over (args, kwargs).
        if not is_const(node.args, exported_program, const_node_to_tensor):
            continue

        # Materialize concrete values for every arg/kwarg leaf.
        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )

        # Execute the op eagerly and record the propagated constant tensor.
        const_node_to_tensor[node] = node.target(*args_data, **kwargs_data)

    return const_node_to_tensor


def get_propagated_const_tensor_dict_seed(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """Thin alias over `get_constant_placeholder_dict` for readability."""
    return get_constant_placeholder_dict(exported_program)
140
+
141
+
142
+ def get_first_user_input (exported_program : ExportedProgram ) -> torch .fx .Node :
143
+ """Returns the first user input node in the graph."""
59
144
first_user_input = None
60
145
for node in exported_program .graph .nodes :
61
146
if (
@@ -64,11 +149,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64
149
):
65
150
first_user_input = node
66
151
break
152
+ return first_user_input
153
+
154
+
155
+ def replace_with_constant_node (
156
+ node : torch .fx .Node ,
157
+ prop_constant_tensor : torch .Tensor ,
158
+ first_user_input : torch .fx .Node ,
159
+ fake_mode ,
160
+ exported_program : ExportedProgram ,
161
+ ) -> tuple [torch .fx .Node , str ]:
162
+ # Add `prop_constant_tensor` to program.state_dict.
163
+ prop_constant_tensor_fqn = f"_prop_tensor_constant{ len (exported_program .constants )} "
164
+ exported_program .constants [prop_constant_tensor_fqn ] = prop_constant_tensor
165
+
166
+ # Insert a new placeholder node for the propagated constant tensor.
167
+ with exported_program .graph .inserting_before (first_user_input ):
168
+ const_placeholder_node = exported_program .graph .placeholder (
169
+ prop_constant_tensor_fqn
170
+ )
171
+
172
+ # Update the meta data of the new placeholder (buffer) node.
173
+ for k , v in node .meta .items ():
174
+ const_placeholder_node .meta [k ] = v
175
+ const_placeholder_node .meta ["val" ] = fake_mode .from_tensor (
176
+ prop_constant_tensor , static_shapes = True
177
+ )
178
+ const_placeholder_node .meta ["val" ].constant = prop_constant_tensor
179
+
180
+ # Replace the original node with the new constant node.
181
+ node .replace_all_uses_with (const_placeholder_node )
182
+ exported_program .graph .erase_node (node )
183
+
184
+ return const_placeholder_node , prop_constant_tensor_fqn
67
185
68
- buffers = exported_program .graph_signature .buffers
69
- prop_constant_data = []
70
- const_data_to_be_removed = set ()
71
186
187
+ def get_fake_mode (exported_program : ExportedProgram ):
72
188
fake_mode = detect_fake_mode (
73
189
tuple (
74
190
node .meta ["val" ]
@@ -77,57 +193,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77
193
)
78
194
)
79
195
assert fake_mode is not None
196
+ return fake_mode
80
197
198
+
199
+ def erase_constant_node (
200
+ exported_program : ExportedProgram ,
201
+ node : torch .fx .Node ,
202
+ ):
203
+ # Remove from graph.
204
+ exported_program .graph .erase_node (node )
205
+
206
+ # Remove corresponding tensor from param/constants dict.
207
+ signature = exported_program .graph_signature
208
+ if name := signature .inputs_to_parameters .pop (node .name , None ):
209
+ exported_program .state_dict .pop (name , None )
210
+ elif name := signature .inputs_to_lifted_tensor_constants .pop (node .name , None ):
211
+ exported_program .constants .pop (name , None )
212
+ elif name := signature .inputs_to_buffers .pop (node .name , None ):
213
+ exported_program .constants .pop (name , None )
214
+ exported_program .state_dict .pop (name , None )
215
+
216
+
217
def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Creates constant nodes for all entries in `const_node_to_tensor` and
    returns a node.name -> InputSpec dict for the new constant placeholders.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}
    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Walk entries newest-first so that a node's (constant) users are
    # handled before the node itself is considered for erasure.
    for node, tensor in reversed(list(const_node_to_tensor.items())):
        if all(user in const_node_to_tensor for user in node.users):
            # Every user was folded too (or there are no users), so this
            # node is dead after propagation — drop it entirely.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            # An existing lifted input that is still used; keep it as is.
            continue

        const_placeholder_node, fqn = replace_with_constant_node(
            node, tensor, first_user_input, fake_mode, exported_program
        )

        # Record an input spec so the signature can reference the constant.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=fqn,
            persistent=True,
        )
    return name_to_spec_dict
251
+
252
+
253
def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    This pass is for constant propagation for Exported Program with lifted parameters,
    as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.

    Args:
        exported_program: The ExportedProgram to perform constant propagation on.
        custom_skip_targets: Optional set of EdgeOpOverload targets to skip
            during constant propagation.

    Returns:
        The modified ExportedProgram with constant propagation applied.
    """
    # Nothing to do for a graph without any placeholder inputs.
    if not any(
        node.op == "placeholder" for node in exported_program.graph.nodes
    ):
        return exported_program

    # Control flow (torch.cond) is not handled by this pass.
    if any(
        node.target == torch.ops.higher_order.cond
        for node in exported_program.graph.nodes
    ):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Start from the existing input specs (keyed by argument name), then
    # overlay specs created for the new constant placeholders.
    name_to_spec_dict = {
        spec.arg.name: spec
        for spec in exported_program.graph_signature.input_specs
    }
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Rebuild the input specs in the graph's current placeholder order.
    new_input_specs = [
        name_to_spec_dict[node.name]
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]
    exported_program.graph_signature.input_specs = new_input_specs

    # Cleanup the graph.
    exported_program.graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program
0 commit comments