Skip to content

Commit 194b751

Browse files
dulinriley authored and facebook-github-bot committed
Fixes for constant_prop_pass (#2967)
Summary: The constant_prop_pass didn't properly propagate constants when there were simple primitives in the argument set. Extend it to see floats, ints, strings, etc. as constant functions. This allows this pass to fold additional things like quantize functions on weights. Sometimes users don't want that, so allow them to use a lambda to skip some nodes. Differential Revision: D55942686
1 parent a983ebc commit 194b751

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,36 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Callable, List, Optional
8+
79
import torch
810
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
911
from torch._guards import detect_fake_mode
1012
from torch.export import ExportedProgram
1113
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
1214

1315

14-
def is_const(arg, exported_program, const_data_list) -> bool:
16+
# Argument types that constant propagation treats as constants in their own
# right: an argument of one of these types never blocks a node from being
# folded (see is_const / get_data below).
_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
26+
27+
28+
def is_const(
29+
arg: object, exported_program: ExportedProgram, const_data_list: List[str]
30+
) -> bool:
1531
if isinstance(arg, (tuple, list)):
1632
return all(is_const(x, exported_program, const_data_list) for x in arg)
1733
elif isinstance(arg, dict):
1834
return all(is_const(x, exported_program, const_data_list) for x in arg.values())
35+
elif isinstance(arg, _PRIMITIVE_TYPES):
36+
return True
1937
elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
2038
return False
2139
elif (
@@ -27,17 +45,22 @@ def is_const(arg, exported_program, const_data_list) -> bool:
2745
return False
2846

2947

def get_data(exported_program: ExportedProgram, arg: object) -> object:
    """Resolve ``arg`` to its concrete constant value, if it has one.

    Recurses into tuples/lists, passes primitive values (see
    ``_PRIMITIVE_TYPES``) through unchanged, and looks up parameter/buffer
    placeholders in ``exported_program``.  Returns ``None`` when ``arg`` is
    not backed by constant data.
    """
    if isinstance(arg, (tuple, list)):
        # Resolve element-wise.  NOTE: the result is always a list, even when
        # the input was a tuple.
        return [get_data(exported_program, x) for x in arg]
    elif isinstance(arg, _PRIMITIVE_TYPES):
        # A primitive is its own constant value.
        return arg
    elif is_param(exported_program, arg):
        return get_param(exported_program, arg)
    elif is_buffer(exported_program, arg):
        return get_buffer(exported_program, arg)
    return None
3858

3959

40-
def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
60+
def constant_prop_pass(
61+
exported_program: ExportedProgram,
62+
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
63+
) -> ExportedProgram:
4164
"""
4265
This pass is for constant propagation for Exported Program with lifted parameters,
4366
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
@@ -56,12 +79,14 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
5679
if len(has_cond) > 0:
5780
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
5881

82+
first_user_input_idx = -1
5983
first_user_input = None
60-
for node in exported_program.graph.nodes:
84+
for i, node in enumerate(exported_program.graph.nodes):
6185
if (
6286
node.op == "placeholder"
6387
and node.name in exported_program.graph_signature.user_inputs
6488
):
89+
first_user_input_idx = i
6590
first_user_input = node
6691
break
6792

@@ -79,6 +104,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
79104
assert fake_mode is not None
80105

81106
for node in exported_program.graph.nodes:
107+
if skip_folding_node_fn is not None and skip_folding_node_fn(node):
108+
# Do not process this node if we were told to skip it.
109+
continue
82110
if node.op == "call_function":
83111
constant_data_name_list = [
84112
input_spec.target for input_spec in prop_constant_data
@@ -115,9 +143,11 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
115143
exported_program.state_dict[prop_constant_tensor_fqn] = (
116144
prop_constant_tensor
117145
)
118-
exported_program.graph_signature.input_specs.append(
119-
prop_constant_node_input_spec
146+
# Insert new buffers before the first user input.
147+
exported_program.graph_signature.input_specs.insert(
148+
first_user_input_idx, prop_constant_node_input_spec
120149
)
150+
first_user_input_idx += 1
121151

122152
# Remove the propagated buffer from the state dict
123153
for node in exported_program.graph.nodes:
@@ -128,6 +158,16 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
128158
):
129159
exported_program.state_dict.pop(node.name, None)
130160
exported_program.graph.erase_node(node)
161+
# Delete the input spec for this deleted buffer.
162+
to_erase_idx = []
163+
for i, spec in enumerate(exported_program.graph_signature.input_specs):
164+
if spec.arg.name == node.name:
165+
to_erase_idx.append(i)
166+
assert (
167+
len(to_erase_idx) == 1
168+
), f"Should only delete one spec per node, but deleting multiple: {to_erase_idx} {exported_program.graph_signature.input_specs}"
169+
for i in reversed(to_erase_idx):
170+
exported_program.graph_signature.input_specs.pop(i)
131171

132172
exported_program.graph_module.recompile()
133173
return exported_program

0 commit comments

Comments
 (0)