Skip to content

Commit 5750ca8

Browse files
hsharma35 authored and facebook-github-bot committed
Extend constant prop pass to work with int/float/etc scalars and fix input specs. (#2950)
Summary: 1. Cleanup / Refactor constant prop pass. 2. Enable constant propagation for ops with constant scalar arguments -- int/float/dtype/bool/str. Nodes of type `Op(constant_tensor, some_int, some_float, some_dtype, ...)` can now be constant propagated. 3. Fix order of input spec to match the expected spec in `ExportGraphSignature` class. parameters->buffers->constants->user_inputs. Before this diff, input_specs for the newly added constant tensors were appended to graph_signature, which would cause failures. Differential Revision: D55891278
1 parent d761f99 commit 5750ca8

File tree

3 files changed

+360
-84
lines changed

3 files changed

+360
-84
lines changed

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ python_library(
9292
],
9393
deps = [
9494
"//caffe2:torch",
95+
"//executorch/exir/dialects:lib",
96+
"//executorch/exir/dialects/edge:lib",
9597
],
9698
)
9799

exir/passes/constant_prop_pass.py

Lines changed: 257 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,143 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from collections import OrderedDict
8+
from typing import cast, Mapping, Optional
9+
710
import torch
8-
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
13+
from torch._export.utils import (
14+
get_buffer,
15+
get_lifted_tensor_constant,
16+
get_param,
17+
is_buffer,
18+
is_lifted_tensor_constant,
19+
is_param,
20+
)
921
from torch._guards import detect_fake_mode
1022
from torch.export import ExportedProgram
1123
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
24+
from torch.utils import _pytree as pytree
25+
26+
27+
# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
28+
# Propagating aten.full can significantly increase compiled model size.
29+
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
1230

31+
_PRIMITIVE_TYPES = (
32+
float,
33+
int,
34+
bool,
35+
str,
36+
torch.Tensor,
37+
torch.device,
38+
torch.dtype,
39+
torch.layout,
40+
)
1341

14-
def is_const(arg, exported_program, const_data_list) -> bool:
42+
43+
def is_const(
44+
arg,
45+
exported_program: ExportedProgram,
46+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
47+
) -> bool:
1548
if isinstance(arg, (tuple, list)):
16-
return all(is_const(x, exported_program, const_data_list) for x in arg)
49+
return all(is_const(x, exported_program, const_node_to_tensor) for x in arg)
1750
elif isinstance(arg, dict):
18-
return all(is_const(x, exported_program, const_data_list) for x in arg.values())
19-
elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
51+
return all(
52+
is_const(x, exported_program, const_node_to_tensor) for x in arg.values()
53+
)
54+
elif isinstance(arg, _PRIMITIVE_TYPES):
55+
return True
56+
elif not isinstance(arg, torch.fx.Node):
2057
return False
21-
elif (
22-
is_param(exported_program, arg)
23-
or is_buffer(exported_program, arg)
24-
or arg.name in const_data_list
25-
):
58+
elif arg in const_node_to_tensor:
2659
return True
2760
return False
2861

2962

30-
def get_data(exported_program, arg):
63+
def get_data(
64+
arg,
65+
exported_program: ExportedProgram,
66+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
67+
):
3168
if isinstance(arg, (tuple, list)):
32-
return [get_data(exported_program, x) for x in arg]
33-
elif is_param(exported_program, arg):
34-
return get_param(exported_program, arg)
35-
elif is_buffer(exported_program, arg):
36-
return get_buffer(exported_program, arg)
69+
return [get_data(x, exported_program, const_node_to_tensor) for x in arg]
70+
elif isinstance(arg, _PRIMITIVE_TYPES):
71+
return arg
72+
elif arg in const_node_to_tensor:
73+
return const_node_to_tensor[arg]
3774
return None
3875

3976

40-
def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77+
def get_constant_placeholder_dict(
78+
exported_program: ExportedProgram,
79+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
4180
"""
42-
This pass is for constant propagation for Exported Program with lifted parameters,
43-
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
81+
Returns a dictionary of placeholder node -> constant tensor.
4482
"""
45-
if (
46-
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
47-
== 0
48-
):
49-
return exported_program
83+
const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
84+
for node in exported_program.graph.nodes:
85+
if node.op != "placeholder":
86+
continue
87+
88+
if is_param(exported_program, node):
89+
const_node_to_tensor[node] = cast(
90+
torch.Tensor, get_param(exported_program, node)
91+
)
92+
elif is_buffer(exported_program, node):
93+
const_node_to_tensor[node] = cast(
94+
torch.Tensor, get_buffer(exported_program, node)
95+
)
96+
elif is_lifted_tensor_constant(exported_program, node):
97+
const_node_to_tensor[node] = cast(
98+
torch.Tensor, get_lifted_tensor_constant(exported_program, node)
99+
)
100+
return const_node_to_tensor
50101

51-
has_cond = [
52-
node
53-
for node in exported_program.graph.nodes
54-
if node.target == torch.ops.higher_order.cond
55-
]
56-
if len(has_cond) > 0:
57-
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
58102

103+
def get_propagated_const_tensor_dict(
104+
exported_program: ExportedProgram,
105+
custom_skip_targets: Optional[set[EdgeOpOverload]],
106+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
107+
"""
108+
Propagates constants and returns a dictionary of node->constant tensors.
109+
"""
110+
# Initialize dict with all constant placeholders.
111+
const_node_to_tensor = get_constant_placeholder_dict(exported_program)
112+
113+
all_skip_targets: set[EdgeOpOverload] = set()
114+
# Default set of targets to skip.
115+
all_skip_targets.update(_DEFAULT_SKIP_TARGETS)
116+
if custom_skip_targets is not None:
117+
all_skip_targets.update(custom_skip_targets)
118+
119+
for node in exported_program.graph.nodes:
120+
if node.op != "call_function" or node.target in all_skip_targets:
121+
continue
122+
123+
if not is_const(
124+
node.args,
125+
exported_program,
126+
const_node_to_tensor,
127+
):
128+
continue
129+
130+
args_data, kwargs_data = pytree.tree_map(
131+
lambda x: get_data(x, exported_program, const_node_to_tensor),
132+
(node.args, node.kwargs),
133+
)
134+
135+
# Execute the `node.target` and create a new propagated constant tensor.
136+
prop_constant_tensor = node.target(*args_data, **kwargs_data)
137+
const_node_to_tensor[node] = prop_constant_tensor
138+
139+
return const_node_to_tensor
140+
141+
142+
def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
143+
"""Returns the first user input node in the graph."""
59144
first_user_input = None
60145
for node in exported_program.graph.nodes:
61146
if (
@@ -64,11 +149,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64149
):
65150
first_user_input = node
66151
break
152+
return first_user_input
153+
154+
155+
def replace_with_constant_node(
156+
node: torch.fx.Node,
157+
prop_constant_tensor: torch.Tensor,
158+
first_user_input: torch.fx.Node,
159+
fake_mode,
160+
exported_program: ExportedProgram,
161+
) -> tuple[torch.fx.Node, str]:
162+
# Add `prop_constant_tensor` to program.state_dict.
163+
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
164+
exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
165+
166+
# Insert a new placeholder node for the propagated constant tensor.
167+
with exported_program.graph.inserting_before(first_user_input):
168+
const_placeholder_node = exported_program.graph.placeholder(
169+
prop_constant_tensor_fqn
170+
)
171+
172+
# Update the meta data of the new placeholder (buffer) node.
173+
for k, v in node.meta.items():
174+
const_placeholder_node.meta[k] = v
175+
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
176+
prop_constant_tensor, static_shapes=True
177+
)
178+
const_placeholder_node.meta["val"].constant = prop_constant_tensor
179+
180+
# Replace the original node with the new constant node.
181+
node.replace_all_uses_with(const_placeholder_node)
182+
exported_program.graph.erase_node(node)
183+
184+
return const_placeholder_node, prop_constant_tensor_fqn
67185

68-
buffers = exported_program.graph_signature.buffers
69-
prop_constant_data = []
70-
const_data_to_be_removed = set()
71186

187+
def get_fake_mode(exported_program: ExportedProgram):
72188
fake_mode = detect_fake_mode(
73189
tuple(
74190
node.meta["val"]
@@ -77,57 +193,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77193
)
78194
)
79195
assert fake_mode is not None
196+
return fake_mode
80197

198+
199+
def erase_constant_node(
200+
exported_program: ExportedProgram,
201+
node: torch.fx.Node,
202+
):
203+
# Remove from graph.
204+
exported_program.graph.erase_node(node)
205+
206+
# Remove corresponding tensor from param/constants dict.
207+
signature = exported_program.graph_signature
208+
if name := signature.inputs_to_parameters.pop(node.name, None):
209+
exported_program.state_dict.pop(name, None)
210+
elif name := signature.inputs_to_lifted_tensor_constants.pop(node.name, None):
211+
exported_program.constants.pop(name, None)
212+
elif name := signature.inputs_to_buffers.pop(node.name, None):
213+
exported_program.constants.pop(name, None)
214+
exported_program.state_dict.pop(name, None)
215+
216+
217+
def create_constant_nodes_and_return_specs(
218+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
219+
exported_program: ExportedProgram,
220+
) -> dict[str, InputSpec]:
221+
"""
222+
Creates constant nodes for all entries in `const_node_to_tensor` and returns a node.name -> InputSpec dict.
223+
"""
224+
name_to_spec_dict: dict[str, InputSpec] = {}
225+
226+
fake_mode = get_fake_mode(exported_program)
227+
first_user_input = get_first_user_input(exported_program)
228+
229+
# Iterate over nodes in reverse order.
230+
for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
231+
if all(x in const_node_to_tensor for x in node.users):
232+
# All users of this constant node are also constant, so we don't need to create a new constant node.
233+
erase_constant_node(exported_program, node)
234+
continue
235+
236+
if node.op == "placeholder":
237+
continue
238+
239+
const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
240+
node, prop_constant_tensor, first_user_input, fake_mode, exported_program
241+
)
242+
243+
# Create input spec for lifted constant.
244+
name_to_spec_dict[const_placeholder_node.name] = InputSpec(
245+
kind=InputKind.CONSTANT_TENSOR,
246+
arg=TensorArgument(name=const_placeholder_node.name),
247+
target=prop_constant_tensor_fqn,
248+
persistent=True,
249+
)
250+
return name_to_spec_dict
251+
252+
253+
def constant_prop_pass(
254+
exported_program: ExportedProgram,
255+
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
256+
) -> ExportedProgram:
257+
"""
258+
This pass is for constant propagation for Exported Program with lifted parameters,
259+
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
260+
261+
Args:
262+
exported_program: The ExportedProgram to perform constant propagation on.
263+
custom_skip_targets: Optional set of EdgeOpOverload targets to skip during constant propagation.
264+
265+
Returns:
266+
The modified ExportedProgram with constant propagation applied.
267+
"""
268+
if (
269+
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
270+
== 0
271+
):
272+
return exported_program
273+
274+
has_control_flow = [
275+
node
276+
for node in exported_program.graph.nodes
277+
if node.target == torch.ops.higher_order.cond
278+
]
279+
if len(has_control_flow) > 0:
280+
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
281+
282+
const_node_to_tensor = get_propagated_const_tensor_dict(
283+
exported_program, custom_skip_targets
284+
)
285+
286+
# Get old input specs.
287+
name_to_spec_dict = {
288+
s.arg.name: s for s in exported_program.graph_signature.input_specs
289+
}
290+
# Add the new constants to input specs dict.
291+
name_to_spec_dict.update(
292+
create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
293+
)
294+
295+
# Generate new input spec.
296+
new_input_specs = []
81297
for node in exported_program.graph.nodes:
82-
if node.op == "call_function":
83-
constant_data_name_list = [
84-
input_spec.target for input_spec in prop_constant_data
85-
]
86-
if is_const(node.args, exported_program, constant_data_name_list):
87-
args_data = [get_data(exported_program, arg) for arg in node.args]
88-
kwargs_data = node.kwargs
89-
const_data_to_be_removed.update(node.args)
90-
prop_constant_tensor = node.target(*args_data, **kwargs_data)
91-
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(buffers)}"
92-
93-
with exported_program.graph.inserting_before(first_user_input):
94-
const_placeholder_node = exported_program.graph.placeholder(
95-
prop_constant_tensor_fqn
96-
)
97-
# Update the meta data of the new placeholder (buffer) node
98-
for k, v in node.meta.items():
99-
const_placeholder_node.meta[k] = v
100-
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
101-
prop_constant_tensor, static_shapes=True
102-
)
103-
const_placeholder_node.meta["val"].constant = prop_constant_tensor
104-
105-
node.replace_all_uses_with(const_placeholder_node)
106-
exported_program.graph.erase_node(node)
107-
prop_constant_node_input_spec = InputSpec(
108-
kind=InputKind.BUFFER,
109-
arg=TensorArgument(name=const_placeholder_node.name),
110-
target=prop_constant_tensor_fqn,
111-
persistent=True,
112-
)
113-
prop_constant_data.append(prop_constant_node_input_spec)
114-
buffers.append(prop_constant_tensor_fqn)
115-
exported_program.state_dict[prop_constant_tensor_fqn] = (
116-
prop_constant_tensor
117-
)
118-
exported_program.graph_signature.input_specs.append(
119-
prop_constant_node_input_spec
120-
)
121-
122-
# Remove the propogated buffer from the state dict
123-
for node in exported_program.graph.nodes:
124-
if (
125-
node.op == "placeholder"
126-
and node in const_data_to_be_removed
127-
and len(node.users) == 0
128-
):
129-
exported_program.state_dict.pop(node.name, None)
130-
exported_program.graph.erase_node(node)
298+
if node.op != "placeholder":
299+
continue
300+
new_input_specs.append(name_to_spec_dict[node.name])
301+
exported_program.graph_signature.input_specs = new_input_specs
131302

303+
# Cleanup the graph.
304+
exported_program.graph.eliminate_dead_code()
132305
exported_program.graph_module.recompile()
306+
133307
return exported_program

0 commit comments

Comments (0)