Skip to content

Commit 1dbb47e

Browse files
hsharma35 authored and facebook-github-bot committed
Extend constant prop pass to work with int/float/etc scalars and fix input specs.
Summary: 1. Cleanup / Refactor constant prop pass. 2. Enable constant propagation for ops with constant scalar arguments -- int/float/dtype/bool/str. Nodes of type `Op(constant_tensor, some_int, some_float, some_dtype, ...)` can now be constant propagated. 3. Fix order of input spec to match the expected spec in `ExportGraphSignature` class: parameters -> buffers -> constants -> user_inputs. Before this diff, input_specs for the newly added constant tensors were appended to graph_signature, which would cause failures. Differential Revision: D55891278
1 parent c4ac14c commit 1dbb47e

File tree

3 files changed

+313
-84
lines changed

3 files changed

+313
-84
lines changed

exir/passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ python_library(
9292
],
9393
deps = [
9494
"//caffe2:torch",
95+
"//executorch/exir/dialects:lib",
9596
],
9697
)
9798

exir/passes/constant_prop_pass.py

Lines changed: 238 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,138 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from collections import OrderedDict
8+
from typing import cast, Mapping
9+
710
import torch
8-
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from torch._export.utils import (
13+
get_buffer,
14+
get_lifted_tensor_constant,
15+
get_param,
16+
is_buffer,
17+
is_lifted_tensor_constant,
18+
is_param,
19+
)
920
from torch._guards import detect_fake_mode
1021
from torch.export import ExportedProgram
1122
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
23+
from torch.utils import _pytree as pytree
24+
25+
26+
# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
27+
# Propagating aten.full can significantly increase compiled model size.
28+
SKIP_TARGETS = {exir_ops.edge.aten.full.default}
1229

1330

14-
def is_const(arg, exported_program, const_data_list) -> bool:
31+
def is_const(
32+
arg,
33+
exported_program: ExportedProgram,
34+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
35+
allow_propagation_of_scalars: bool,
36+
) -> bool:
1537
if isinstance(arg, (tuple, list)):
16-
return all(is_const(x, exported_program, const_data_list) for x in arg)
38+
return all(
39+
is_const(
40+
x, exported_program, const_node_to_tensor, allow_propagation_of_scalars
41+
)
42+
for x in arg
43+
)
1744
elif isinstance(arg, dict):
18-
return all(is_const(x, exported_program, const_data_list) for x in arg.values())
19-
elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
45+
return all(
46+
is_const(
47+
x, exported_program, const_node_to_tensor, allow_propagation_of_scalars
48+
)
49+
for x in arg.values()
50+
)
51+
elif isinstance(arg, (int, float, bool, str, torch.dtype)):
52+
return allow_propagation_of_scalars
53+
elif not isinstance(arg, torch.fx.Node):
2054
return False
21-
elif (
22-
is_param(exported_program, arg)
23-
or is_buffer(exported_program, arg)
24-
or arg.name in const_data_list
25-
):
55+
elif arg in const_node_to_tensor:
2656
return True
2757
return False
2858

2959

30-
def get_data(exported_program, arg):
60+
def get_data(
61+
arg,
62+
exported_program: ExportedProgram,
63+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
64+
):
3165
if isinstance(arg, (tuple, list)):
32-
return [get_data(exported_program, x) for x in arg]
33-
elif is_param(exported_program, arg):
34-
return get_param(exported_program, arg)
35-
elif is_buffer(exported_program, arg):
36-
return get_buffer(exported_program, arg)
66+
return [get_data(x, exported_program, const_node_to_tensor) for x in arg]
67+
elif isinstance(arg, (int, float, bool, str, torch.dtype)):
68+
return arg
69+
elif arg in const_node_to_tensor:
70+
return const_node_to_tensor[arg]
3771
return None
3872

3973

40-
def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
74+
def get_constant_placeholder_dict(
75+
exported_program: ExportedProgram,
76+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
4177
"""
42-
This pass is for constant propagation for Exported Program with lifted parameters,
43-
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
78+
Returns a dictionary of placeholder node -> constant tensor.
4479
"""
45-
if (
46-
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
47-
== 0
48-
):
49-
return exported_program
80+
const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
81+
for node in exported_program.graph.nodes:
82+
if node.op != "placeholder":
83+
continue
84+
85+
if is_param(exported_program, node):
86+
const_node_to_tensor[node] = cast(
87+
torch.Tensor, get_param(exported_program, node)
88+
)
89+
elif is_buffer(exported_program, node):
90+
const_node_to_tensor[node] = cast(
91+
torch.Tensor, get_buffer(exported_program, node)
92+
)
93+
elif is_lifted_tensor_constant(exported_program, node):
94+
const_node_to_tensor[node] = cast(
95+
torch.Tensor, get_lifted_tensor_constant(exported_program, node)
96+
)
97+
return const_node_to_tensor
5098

51-
has_cond = [
52-
node
53-
for node in exported_program.graph.nodes
54-
if node.target == torch.ops.higher_order.cond
55-
]
56-
if len(has_cond) > 0:
57-
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
5899

100+
def get_propagated_const_tensor_dict(
101+
exported_program: ExportedProgram,
102+
allow_propagation_of_scalars: bool,
103+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
104+
"""
105+
Propagates constants and returns a dictionary of node->constant tensors.
106+
"""
107+
# Initialize dict with all constant placeholders.
108+
const_node_to_tensor = get_constant_placeholder_dict(exported_program)
109+
110+
for node in exported_program.graph.nodes:
111+
if node.op == "placeholder":
112+
continue
113+
114+
if node.op != "call_function" or node.target in SKIP_TARGETS:
115+
continue
116+
117+
if not is_const(
118+
node.args,
119+
exported_program,
120+
const_node_to_tensor,
121+
allow_propagation_of_scalars=allow_propagation_of_scalars,
122+
):
123+
continue
124+
125+
args_data, kwargs_data = pytree.tree_map(
126+
lambda x: get_data(x, exported_program, const_node_to_tensor),
127+
(node.args, node.kwargs),
128+
)
129+
130+
# Execute the `node.target` and create a new propagated constant tensor.
131+
prop_constant_tensor = node.target(*args_data, **kwargs_data)
132+
const_node_to_tensor[node] = prop_constant_tensor
133+
134+
return const_node_to_tensor
135+
136+
137+
def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
138+
"""Returns the first user input node in the graph."""
59139
first_user_input = None
60140
for node in exported_program.graph.nodes:
61141
if (
@@ -64,11 +144,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64144
):
65145
first_user_input = node
66146
break
147+
return first_user_input
148+
149+
150+
def replace_with_constant_node(
151+
node: torch.fx.Node,
152+
prop_constant_tensor: torch.Tensor,
153+
first_user_input: torch.fx.Node,
154+
fake_mode,
155+
exported_program: ExportedProgram,
156+
) -> tuple[torch.fx.Node, str]:
157+
# Add `prop_constant_tensor` to program.constants.
158+
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
159+
exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
160+
161+
# Insert a new placeholder node for the propagated constant tensor.
162+
with exported_program.graph.inserting_before(first_user_input):
163+
const_placeholder_node = exported_program.graph.placeholder(
164+
prop_constant_tensor_fqn
165+
)
166+
167+
# Update the meta data of the new placeholder (buffer) node.
168+
for k, v in node.meta.items():
169+
const_placeholder_node.meta[k] = v
170+
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
171+
prop_constant_tensor, static_shapes=True
172+
)
173+
const_placeholder_node.meta["val"].constant = prop_constant_tensor
174+
175+
# Replace the original node with the new constant node.
176+
node.replace_all_uses_with(const_placeholder_node)
177+
exported_program.graph.erase_node(node)
67178

68-
buffers = exported_program.graph_signature.buffers
69-
prop_constant_data = []
70-
const_data_to_be_removed = set()
179+
return const_placeholder_node, prop_constant_tensor_fqn
71180

181+
182+
def get_fake_mode(exported_program: ExportedProgram):
72183
fake_mode = detect_fake_mode(
73184
tuple(
74185
node.meta["val"]
@@ -77,57 +188,101 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77188
)
78189
)
79190
assert fake_mode is not None
191+
return fake_mode
80192

193+
194+
def erase_constant_node(
195+
exported_program: ExportedProgram,
196+
node: torch.fx.Node,
197+
):
198+
# Remove from graph.
199+
exported_program.graph.erase_node(node)
200+
201+
# Remove corresponding tensor from param/constants dict.
202+
signature = exported_program.graph_signature
203+
if name := signature.inputs_to_parameters.pop(node.name, None):
204+
exported_program.state_dict.pop(name, None)
205+
elif name := signature.inputs_to_lifted_tensor_constants.pop(node.name, None):
206+
exported_program.constants.pop(name, None)
207+
208+
209+
def create_constant_nodes_and_return_specs(
210+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
211+
exported_program: ExportedProgram,
212+
) -> dict[str, InputSpec]:
213+
"""
214+
Creates constant nodes for all entries in `const_node_to_tensor` and returns a node.name -> InputSpec dict.
215+
"""
216+
name_to_spec_dict: dict[str, InputSpec] = {}
217+
218+
fake_mode = get_fake_mode(exported_program)
219+
first_user_input = get_first_user_input(exported_program)
220+
221+
# Iterate over nodes in reverse order.
222+
for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
223+
if all(x in const_node_to_tensor for x in node.users):
224+
# All users of this constant node are also constant, so we don't need to create a new constant node.
225+
erase_constant_node(exported_program, node)
226+
continue
227+
228+
const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
229+
node, prop_constant_tensor, first_user_input, fake_mode, exported_program
230+
)
231+
232+
# Create input spec for lifted constant.
233+
name_to_spec_dict[const_placeholder_node.name] = InputSpec(
234+
kind=InputKind.CONSTANT_TENSOR,
235+
arg=TensorArgument(name=const_placeholder_node.name),
236+
target=prop_constant_tensor_fqn,
237+
persistent=True,
238+
)
239+
return name_to_spec_dict
240+
241+
242+
def constant_prop_pass(
243+
exported_program: ExportedProgram, allow_propagation_of_scalars: bool = False
244+
) -> ExportedProgram:
245+
"""
246+
This pass is for constant propagation for Exported Program with lifted parameters,
247+
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
248+
"""
249+
if (
250+
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
251+
== 0
252+
):
253+
return exported_program
254+
255+
has_control_flow = [
256+
node
257+
for node in exported_program.graph.nodes
258+
if node.target == torch.ops.higher_order.cond
259+
]
260+
if len(has_control_flow) > 0:
261+
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
262+
263+
const_node_to_tensor = get_propagated_const_tensor_dict(
264+
exported_program, allow_propagation_of_scalars=allow_propagation_of_scalars
265+
)
266+
267+
# Get old input specs.
268+
name_to_spec_dict = {
269+
s.arg.name: s for s in exported_program.graph_signature.input_specs
270+
}
271+
# Add the new constants to input specs dict.
272+
name_to_spec_dict.update(
273+
create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
274+
)
275+
276+
# Generate new input spec.
277+
new_input_specs = []
81278
for node in exported_program.graph.nodes:
82-
if node.op == "call_function":
83-
constant_data_name_list = [
84-
input_spec.target for input_spec in prop_constant_data
85-
]
86-
if is_const(node.args, exported_program, constant_data_name_list):
87-
args_data = [get_data(exported_program, arg) for arg in node.args]
88-
kwargs_data = node.kwargs
89-
const_data_to_be_removed.update(node.args)
90-
prop_constant_tensor = node.target(*args_data, **kwargs_data)
91-
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(buffers)}"
92-
93-
with exported_program.graph.inserting_before(first_user_input):
94-
const_placeholder_node = exported_program.graph.placeholder(
95-
prop_constant_tensor_fqn
96-
)
97-
# Update the meta data of the new placeholder (buffer) node
98-
for k, v in node.meta.items():
99-
const_placeholder_node.meta[k] = v
100-
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
101-
prop_constant_tensor, static_shapes=True
102-
)
103-
const_placeholder_node.meta["val"].constant = prop_constant_tensor
104-
105-
node.replace_all_uses_with(const_placeholder_node)
106-
exported_program.graph.erase_node(node)
107-
prop_constant_node_input_spec = InputSpec(
108-
kind=InputKind.BUFFER,
109-
arg=TensorArgument(name=const_placeholder_node.name),
110-
target=prop_constant_tensor_fqn,
111-
persistent=True,
112-
)
113-
prop_constant_data.append(prop_constant_node_input_spec)
114-
buffers.append(prop_constant_tensor_fqn)
115-
exported_program.state_dict[prop_constant_tensor_fqn] = (
116-
prop_constant_tensor
117-
)
118-
exported_program.graph_signature.input_specs.append(
119-
prop_constant_node_input_spec
120-
)
121-
122-
# Remove the propogated buffer from the state dict
123-
for node in exported_program.graph.nodes:
124-
if (
125-
node.op == "placeholder"
126-
and node in const_data_to_be_removed
127-
and len(node.users) == 0
128-
):
129-
exported_program.state_dict.pop(node.name, None)
130-
exported_program.graph.erase_node(node)
279+
if node.op != "placeholder":
280+
continue
281+
new_input_specs.append(name_to_spec_dict[node.name])
282+
exported_program.graph_signature.input_specs = new_input_specs
131283

284+
# Cleanup the graph.
285+
exported_program.graph.eliminate_dead_code()
132286
exported_program.graph_module.recompile()
287+
133288
return exported_program

0 commit comments

Comments
 (0)