Skip to content

Commit 887613c

Browse files
tarun292facebook-github-bot
authored andcommitted
Fix lift_constant_tensor_pass to make sure constants are not inserted in the state dict
Differential Revision: D55175415
1 parent a41ac1c commit 887613c

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
8+
from torch._export.utils import (
9+
get_buffer,
10+
get_param,
11+
is_buffer,
12+
is_lifted_tensor_constant,
13+
is_param,
14+
)
915
from torch._guards import detect_fake_mode
1016
from torch.export import ExportedProgram
1117
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
@@ -21,6 +27,7 @@ def is_const(arg, exported_program, const_data_list) -> bool:
2127
elif (
2228
is_param(exported_program, arg)
2329
or is_buffer(exported_program, arg)
30+
or is_lifted_tensor_constant(exported_program, arg)
2431
or arg.name in const_data_list
2532
):
2633
return True
@@ -34,6 +41,8 @@ def get_data(exported_program, arg):
3441
return get_param(exported_program, arg)
3542
elif is_buffer(exported_program, arg):
3643
return get_buffer(exported_program, arg)
44+
elif arg.name in exported_program.constants:
45+
return exported_program.constants[arg.name]
3746
return None
3847

3948

exir/program/_program.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def lift_constant_tensor_pass(ep):
186186
if not isinstance(constant_tensor, torch.Tensor):
187187
continue
188188

189-
constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
189+
constant_tensor_fqn = f"_lifted_tensor_constant{len(ep.constants)}"
190190

191191
with ep.graph.inserting_before(first_user_input):
192192
# Insert the constant node before the first user input
@@ -206,14 +206,14 @@ def lift_constant_tensor_pass(ep):
206206
# Add the constant as a buffer to the graph signature
207207
lifted_constants.append(
208208
InputSpec(
209-
kind=InputKind.BUFFER,
209+
kind=InputKind.CONSTANT_TENSOR,
210210
arg=TensorArgument(name=const_placeholder_node.name),
211211
target=constant_tensor_fqn,
212-
persistent=True,
212+
persistent=None,
213213
)
214214
)
215215
buffers.append(constant_tensor_fqn)
216-
ep.state_dict[constant_tensor_fqn] = constant_tensor
216+
ep.constants[constant_tensor_fqn] = constant_tensor
217217

218218
new_input_specs = []
219219
for s in graph_signature.input_specs:

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
11311131

11321132
# Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
11331133
FileCheck().check_not("_lifted_tensor_constant").check(
1134-
"_prop_tensor_constant1"
1134+
"_prop_tensor_constant0"
11351135
).check_not("executorch_exir_dialects_edge__ops_aten__to_copy_default").run(
11361136
new_ep.graph_module.code
11371137
)

0 commit comments

Comments
 (0)