Fixes for constant_prop_pass #2967

Closed
wants to merge 1 commit
exir/passes/constant_prop_pass.py (52 changes: 46 additions & 6 deletions)
@@ -4,18 +4,36 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Callable, List, Optional
+
 import torch
 from torch._export.utils import get_buffer, get_param, is_buffer, is_param
 from torch._guards import detect_fake_mode
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 
 
-def is_const(arg, exported_program, const_data_list) -> bool:
+_PRIMITIVE_TYPES = (
+    float,
+    int,
+    bool,
+    str,
+    torch.Tensor,
+    torch.device,
+    torch.dtype,
+    torch.layout,
+)
+
+
+def is_const(
+    arg: object, exported_program: ExportedProgram, const_data_list: List[str]
+) -> bool:
     if isinstance(arg, (tuple, list)):
         return all(is_const(x, exported_program, const_data_list) for x in arg)
+    elif isinstance(arg, dict):
+        return all(is_const(x, exported_program, const_data_list) for x in arg.values())
+    elif isinstance(arg, _PRIMITIVE_TYPES):
+        return True
     elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
         return False
     elif (
@@ -27,17 +45,22 @@ def is_const(arg, exported_program, const_data_list) -> bool:
     return False
 
 
-def get_data(exported_program, arg):
+def get_data(exported_program: ExportedProgram, arg):
     if isinstance(arg, (tuple, list)):
         return [get_data(exported_program, x) for x in arg]
+    elif isinstance(arg, _PRIMITIVE_TYPES):
+        return arg
     elif is_param(exported_program, arg):
         return get_param(exported_program, arg)
     elif is_buffer(exported_program, arg):
         return get_buffer(exported_program, arg)
     return None
 
 
-def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
+def constant_prop_pass(
+    exported_program: ExportedProgram,
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+) -> ExportedProgram:
     """
     This pass performs constant propagation on an ExportedProgram with lifted parameters,
     where parameters do not show up as `get_attr` nodes but as `placeholder` nodes in the graph.
@@ -56,12 +79,14 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
     if len(has_cond) > 0:
         raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
 
+    first_user_input_idx = -1
     first_user_input = None
-    for node in exported_program.graph.nodes:
+    for i, node in enumerate(exported_program.graph.nodes):
         if (
             node.op == "placeholder"
             and node.name in exported_program.graph_signature.user_inputs
         ):
+            first_user_input_idx = i
             first_user_input = node
             break
 
@@ -79,6 +104,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
     assert fake_mode is not None
 
     for node in exported_program.graph.nodes:
+        if skip_folding_node_fn is not None and skip_folding_node_fn(node):
+            # Do not process this node if we were told to skip it.
+            continue
         if node.op == "call_function":
             constant_data_name_list = [
                 input_spec.target for input_spec in prop_constant_data
@@ -115,9 +143,11 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
                 exported_program.state_dict[prop_constant_tensor_fqn] = (
                     prop_constant_tensor
                 )
-                exported_program.graph_signature.input_specs.append(
-                    prop_constant_node_input_spec
+                # Insert new buffers before the first user input.
+                exported_program.graph_signature.input_specs.insert(
+                    first_user_input_idx, prop_constant_node_input_spec
                 )
+                first_user_input_idx += 1
 
     # Remove the propagated buffer from the state dict
     for node in exported_program.graph.nodes:
@@ -128,6 +158,16 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
         ):
             exported_program.state_dict.pop(node.name, None)
             exported_program.graph.erase_node(node)
+            # Delete the input spec for this deleted buffer.
+            to_erase_idx = []
+            for i, spec in enumerate(exported_program.graph_signature.input_specs):
+                if spec.arg.name == node.name:
+                    to_erase_idx.append(i)
+            assert (
+                len(to_erase_idx) == 1
+            ), f"Should only delete one spec per node, but deleting multiple: {to_erase_idx} {exported_program.graph_signature.input_specs}"
+            for i in reversed(to_erase_idx):
+                exported_program.graph_signature.input_specs.pop(i)
 
     exported_program.graph_module.recompile()
     return exported_program
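
For reviewers who want to try the new hook, here is a minimal sketch of calling the pass with skip_folding_node_fn. The import path and the toy model are assumptions for illustration; only constant_prop_pass and its new keyword argument come from this PR.

import torch
from torch.export import export

# Assumed import path; the pass lives in exir/passes/constant_prop_pass.py.
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class AddConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("offset", torch.ones(2, 2))

    def forward(self, x):
        # `self.offset + 1` depends only on a buffer, so the pass can fold it
        # into a single propagated constant inserted before the user inputs.
        return x + (self.offset + 1)


ep = export(AddConst(), (torch.randn(2, 2),))


def skip_nothing(node: torch.fx.Node) -> bool:
    # Illustrative predicate: return True for any node that should be left
    # unfolded; returning False everywhere folds every eligible node.
    return False


ep = constant_prop_pass(ep, skip_folding_node_fn=skip_nothing)
print(ep.graph)

Passing a real predicate (for example, one that matches quantize-like ops) leaves those nodes in the graph so a later pass can handle them, which is the use case the skip_folding_node_fn parameter enables.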