
Fix lift_constant_tensor_pass to make sure constants are not inserted in the state dict #2558

Closed · wants to merge 1 commit
7 changes: 3 additions & 4 deletions exir/lowered_backend_module.py
@@ -203,10 +203,7 @@ def program(self, emit_stacktrace: bool = False) -> Program:
             for node in lowered_exported_program.graph.nodes
             if (
                 node.op == "placeholder"
-                and node.name
-                not in lowered_exported_program.graph_signature.inputs_to_buffers
-                and node.name
-                not in lowered_exported_program.graph_signature.inputs_to_parameters
+                and node.name in lowered_exported_program.graph_signature.user_inputs
             )
         ]

@@ -230,6 +227,8 @@ def program(self, emit_stacktrace: bool = False) -> Program:
                 node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                 or node.name
                 in lowered_exported_program.graph_signature.inputs_to_parameters
+                or node.name
+                in lowered_exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 lowered_exported_program.graph.erase_node(node)
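
The old check enumerated the signature maps a placeholder must not appear in (buffers, parameters); the new check keeps a placeholder only if the signature lists it as a user input, which also excludes lifted tensor constants. A minimal sketch of the idea, not the executorch source; `exported_program` is assumed to be a torch.export.ExportedProgram:

def user_input_placeholders(exported_program):
    # Keep only placeholders the signature lists as user inputs; this
    # excludes parameters, buffers, and lifted tensor constants alike.
    sig = exported_program.graph_signature
    return [
        node
        for node in exported_program.graph.nodes
        if node.op == "placeholder" and node.name in sig.user_inputs
    ]
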
11 changes: 10 additions & 1 deletion exir/passes/constant_prop_pass.py
@@ -5,7 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torch._export.utils import get_buffer, get_param, is_buffer, is_param
+from torch._export.utils import (
+    get_buffer,
+    get_param,
+    is_buffer,
+    is_lifted_tensor_constant,
+    is_param,
+)
 from torch._guards import detect_fake_mode
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument

@@ -21,6 +27,7 @@ def is_const(arg, exported_program, const_data_list) -> bool:
     elif (
         is_param(exported_program, arg)
         or is_buffer(exported_program, arg)
+        or is_lifted_tensor_constant(exported_program, arg)
         or arg.name in const_data_list
     ):
         return True

@@ -34,6 +41,8 @@ def get_data(exported_program, arg):
         return get_param(exported_program, arg)
     elif is_buffer(exported_program, arg):
         return get_buffer(exported_program, arg)
+    elif arg.name in exported_program.constants:
+        return exported_program.constants[arg.name]
     return None
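
With these changes, is_const treats lifted tensor constants like parameters and buffers, and get_data falls back to exported_program.constants for them. A hedged end-to-end illustration, assuming a torch version whose ExportedProgram exposes .constants (the generated constant name may differ):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # The inline tensor is lifted to a constant input during export.
        return x + torch.tensor([1.0, 2.0])

ep = torch.export.export(M(), (torch.randn(2),))
# Parameters and buffers live in ep.state_dict; lifted tensor constants
# live in ep.constants, which is where get_data now looks them up.
print(list(ep.state_dict.keys()))  # expected: []
print(list(ep.constants.keys()))   # e.g. ['lifted_tensor_0'] (name may vary)
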
2 changes: 2 additions & 0 deletions exir/passes/spec_prop_pass.py
@@ -78,6 +78,8 @@ def update_placeholder_tensor_specs(
             node.target in exported_program.graph_signature.inputs_to_buffers
             and not _is_mutable_buffer(node, exported_program.graph_signature)
         )
+        or node.target
+        in exported_program.graph_signature.inputs_to_lifted_tensor_constants
     ):
         spec.const = True
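
A placeholder's tensor spec is now marked constant for lifted tensor constants as well. A condensed sketch of the visible part of the condition (hypothetical helper, not the executorch source; the real pass additionally excludes mutable buffers via _is_mutable_buffer):

def should_mark_const(node, sig) -> bool:
    # Sketch: a placeholder's spec becomes const when it feeds from a
    # buffer or (after this fix) a lifted tensor constant.
    return (
        node.target in sig.inputs_to_buffers
        or node.target in sig.inputs_to_lifted_tensor_constants
    )
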
8 changes: 4 additions & 4 deletions exir/program/_program.py
@@ -189,7 +189,7 @@ def lift_constant_tensor_pass(ep):
         if not isinstance(constant_tensor, torch.Tensor):
             continue

-        constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
+        constant_tensor_fqn = f"_lifted_tensor_constant{len(ep.constants)}"

         with ep.graph.inserting_before(first_user_input):
             # Insert the constant node before the first user input

@@ -209,14 +209,14 @@ def lift_constant_tensor_pass(ep):
             # Add the constant as a buffer to the graph signature
             lifted_constants.append(
                 InputSpec(
-                    kind=InputKind.BUFFER,
+                    kind=InputKind.CONSTANT_TENSOR,
                     arg=TensorArgument(name=const_placeholder_node.name),
                     target=constant_tensor_fqn,
-                    persistent=True,
+                    persistent=None,
                 )
             )
         buffers.append(constant_tensor_fqn)
-        ep.state_dict[constant_tensor_fqn] = constant_tensor
+        ep.constants[constant_tensor_fqn] = constant_tensor

     new_input_specs = []
     for s in graph_signature.input_specs:
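
This is the heart of the fix: the lifted tensor is registered as a CONSTANT_TENSOR input spec rather than a BUFFER, and the data goes into ep.constants instead of ep.state_dict. A hedged sketch of the registration step (register_lifted_constant is a hypothetical helper name):

from torch.export.exported_program import InputKind, InputSpec, TensorArgument

def register_lifted_constant(ep, placeholder_name, fqn, tensor):
    # Record the lifted tensor as a constant, not a buffer: the data is
    # stored in ep.constants, keeping ep.state_dict reserved for real
    # parameters and buffers.
    ep.constants[fqn] = tensor
    return InputSpec(
        kind=InputKind.CONSTANT_TENSOR,
        arg=TensorArgument(name=placeholder_name),
        target=fqn,
        persistent=None,  # persistence only makes sense for buffers
    )
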
2 changes: 1 addition & 1 deletion exir/tests/TARGETS
@@ -251,7 +251,6 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir:pass_base",
-        "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:backend_details",
         "//executorch/exir/backend:compile_spec_schema",
         "//executorch/exir/backend:partitioner",

@@ -430,6 +429,7 @@ python_unittest(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/exir:dim_order_utils",
         "//executorch/exir:lib",
     ],
 )

2 changes: 1 addition & 1 deletion exir/tests/test_passes.py
@@ -1131,7 +1131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
         FileCheck().check_not("_lifted_tensor_constant").check(
-            "_prop_tensor_constant1"
+            "_prop_tensor_constant0"
         ).check_not("executorch_exir_dialects_edge__ops_aten__to_copy_default").run(
             new_ep.graph_module.code
         )
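
The test's expected name shifts from _prop_tensor_constant1 to _prop_tensor_constant0, matching the new numbering now that lifted constants no longer occupy state-dict slots. For reference, a self-contained FileCheck usage in the same style as the test (the string under test is made up):

from torch.testing import FileCheck

code = "_prop_tensor_constant0 = ...  # propagated constant"
# Assert the propagated name appears and the lifted name does not.
FileCheck().check_not("_lifted_tensor_constant").check(
    "_prop_tensor_constant0"
).run(code)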