Skip to content

Commit a624345

Browse files
tarun292facebook-github-bot
authored andcommitted
Fix lift_constant_tensor_pass to make sure constants are not inserted in the state dict (#2558)
Summary: Pull Request resolved: #2558 Currently in `lift_tensor_constant_pass` the lifted constants are being placed in the state dict which is the wrong place for them. They should be inserted in the `constants` section of the `ExportedProgram` which this diff addresses. Reviewed By: angelayi Differential Revision: D55175415 fbshipit-source-id: afe7f4df94f89de6d0c43dab824c498567a6b21f
1 parent 57e3449 commit a624345

File tree

6 files changed

+21
-11
lines changed

6 files changed

+21
-11
lines changed

exir/lowered_backend_module.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,7 @@ def program(self, emit_stacktrace: bool = False) -> Program:
203203
for node in lowered_exported_program.graph.nodes
204204
if (
205205
node.op == "placeholder"
206-
and node.name
207-
not in lowered_exported_program.graph_signature.inputs_to_buffers
208-
and node.name
209-
not in lowered_exported_program.graph_signature.inputs_to_parameters
206+
and node.name in lowered_exported_program.graph_signature.user_inputs
210207
)
211208
]
212209

@@ -230,6 +227,8 @@ def program(self, emit_stacktrace: bool = False) -> Program:
230227
node.name in lowered_exported_program.graph_signature.inputs_to_buffers
231228
or node.name
232229
in lowered_exported_program.graph_signature.inputs_to_parameters
230+
or node.name
231+
in lowered_exported_program.graph_signature.inputs_to_lifted_tensor_constants
233232
):
234233
lowered_exported_program.graph.erase_node(node)
235234

exir/passes/constant_prop_pass.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
8+
from torch._export.utils import (
9+
get_buffer,
10+
get_param,
11+
is_buffer,
12+
is_lifted_tensor_constant,
13+
is_param,
14+
)
915
from torch._guards import detect_fake_mode
1016
from torch.export import ExportedProgram
1117
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
@@ -21,6 +27,7 @@ def is_const(arg, exported_program, const_data_list) -> bool:
2127
elif (
2228
is_param(exported_program, arg)
2329
or is_buffer(exported_program, arg)
30+
or is_lifted_tensor_constant(exported_program, arg)
2431
or arg.name in const_data_list
2532
):
2633
return True
@@ -34,6 +41,8 @@ def get_data(exported_program, arg):
3441
return get_param(exported_program, arg)
3542
elif is_buffer(exported_program, arg):
3643
return get_buffer(exported_program, arg)
44+
elif arg.name in exported_program.constants:
45+
return exported_program.constants[arg.name]
3746
return None
3847

3948

exir/passes/spec_prop_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def update_placeholder_tensor_specs(
7878
node.target in exported_program.graph_signature.inputs_to_buffers
7979
and not _is_mutable_buffer(node, exported_program.graph_signature)
8080
)
81+
or node.target
82+
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
8183
):
8284
spec.const = True
8385

exir/program/_program.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def lift_constant_tensor_pass(ep):
189189
if not isinstance(constant_tensor, torch.Tensor):
190190
continue
191191

192-
constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
192+
constant_tensor_fqn = f"_lifted_tensor_constant{len(ep.constants)}"
193193

194194
with ep.graph.inserting_before(first_user_input):
195195
# Insert the constant node before the first user input
@@ -209,14 +209,14 @@ def lift_constant_tensor_pass(ep):
209209
# Add the constant as a buffer to the graph signature
210210
lifted_constants.append(
211211
InputSpec(
212-
kind=InputKind.BUFFER,
212+
kind=InputKind.CONSTANT_TENSOR,
213213
arg=TensorArgument(name=const_placeholder_node.name),
214214
target=constant_tensor_fqn,
215-
persistent=True,
215+
persistent=None,
216216
)
217217
)
218218
buffers.append(constant_tensor_fqn)
219-
ep.state_dict[constant_tensor_fqn] = constant_tensor
219+
ep.constants[constant_tensor_fqn] = constant_tensor
220220

221221
new_input_specs = []
222222
for s in graph_signature.input_specs:

exir/tests/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ python_unittest(
251251
"//caffe2:torch",
252252
"//executorch/exir:lib",
253253
"//executorch/exir:pass_base",
254-
"//executorch/exir/backend:backend_api",
255254
"//executorch/exir/backend:backend_details",
256255
"//executorch/exir/backend:compile_spec_schema",
257256
"//executorch/exir/backend:partitioner",
@@ -430,6 +429,7 @@ python_unittest(
430429
],
431430
deps = [
432431
"//caffe2:torch",
432+
"//executorch/exir:dim_order_utils",
433433
"//executorch/exir:lib",
434434
],
435435
)

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
11311131

11321132
# Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
11331133
FileCheck().check_not("_lifted_tensor_constant").check(
1134-
"_prop_tensor_constant1"
1134+
"_prop_tensor_constant0"
11351135
).check_not("executorch_exir_dialects_edge__ops_aten__to_copy_default").run(
11361136
new_ep.graph_module.code
11371137
)

0 commit comments

Comments
 (0)