Skip to content

Commit dc909be

Browse files
tarun292facebook-github-bot
authored andcommitted
Revert D55175415: Multisect successfully blamed "D55175415: Fix lift_constant_tensor_pass to make sure constants are not inserted in the state dict" for one test failure (#2767)
Summary: Pull Request resolved: #2767 This diff reverts D55175415 D55175415: Fix lift_constant_tensor_pass to make sure constants are not inserted in the state dict by tarun292 causes the following test failure: Tests affected: - [cogwheel:cogwheel_gpu_lowering_cws_sampled_lowering_replay_test#main](https://www.internalfb.com/intern/test/844425001247266/) Here's the Multisect link: https://www.internalfb.com/multisect/4734742 Here are the tasks that are relevant to this breakage: The backout may land if someone accepts it. If this diff has been generated in error, you can Commandeer and Abandon it. Reviewed By: dbort Differential Revision: D55521379 fbshipit-source-id: 7521019d07b244876358c6906d1be607a010e055
1 parent d4b3e5c commit dc909be

File tree

5 files changed

+10
-20
lines changed

5 files changed

+10
-20
lines changed

exir/lowered_backend_module.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,10 @@ def program(self, emit_stacktrace: bool = False) -> Program:
203203
for node in lowered_exported_program.graph.nodes
204204
if (
205205
node.op == "placeholder"
206-
and node.name in lowered_exported_program.graph_signature.user_inputs
206+
and node.name
207+
not in lowered_exported_program.graph_signature.inputs_to_buffers
208+
and node.name
209+
not in lowered_exported_program.graph_signature.inputs_to_parameters
207210
)
208211
]
209212

@@ -227,8 +230,6 @@ def program(self, emit_stacktrace: bool = False) -> Program:
227230
node.name in lowered_exported_program.graph_signature.inputs_to_buffers
228231
or node.name
229232
in lowered_exported_program.graph_signature.inputs_to_parameters
230-
or node.name
231-
in lowered_exported_program.graph_signature.inputs_to_lifted_tensor_constants
232233
):
233234
lowered_exported_program.graph.erase_node(node)
234235

exir/passes/constant_prop_pass.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
from torch._export.utils import (
9-
get_buffer,
10-
get_param,
11-
is_buffer,
12-
is_lifted_tensor_constant,
13-
is_param,
14-
)
8+
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
159
from torch._guards import detect_fake_mode
1610
from torch.export import ExportedProgram
1711
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
@@ -27,7 +21,6 @@ def is_const(arg, exported_program, const_data_list) -> bool:
2721
elif (
2822
is_param(exported_program, arg)
2923
or is_buffer(exported_program, arg)
30-
or is_lifted_tensor_constant(exported_program, arg)
3124
or arg.name in const_data_list
3225
):
3326
return True
@@ -41,8 +34,6 @@ def get_data(exported_program, arg):
4134
return get_param(exported_program, arg)
4235
elif is_buffer(exported_program, arg):
4336
return get_buffer(exported_program, arg)
44-
elif arg.name in exported_program.constants:
45-
return exported_program.constants[arg.name]
4637
return None
4738

4839

exir/passes/spec_prop_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ def update_placeholder_tensor_specs(
7878
node.target in exported_program.graph_signature.inputs_to_buffers
7979
and not _is_mutable_buffer(node, exported_program.graph_signature)
8080
)
81-
or node.target
82-
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
8381
):
8482
spec.const = True
8583

exir/program/_program.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def lift_constant_tensor_pass(ep):
189189
if not isinstance(constant_tensor, torch.Tensor):
190190
continue
191191

192-
constant_tensor_fqn = f"_lifted_tensor_constant{len(ep.constants)}"
192+
constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
193193

194194
with ep.graph.inserting_before(first_user_input):
195195
# Insert the constant node before the first user input
@@ -209,14 +209,14 @@ def lift_constant_tensor_pass(ep):
209209
# Add the constant as a buffer to the graph signature
210210
lifted_constants.append(
211211
InputSpec(
212-
kind=InputKind.CONSTANT_TENSOR,
212+
kind=InputKind.BUFFER,
213213
arg=TensorArgument(name=const_placeholder_node.name),
214214
target=constant_tensor_fqn,
215-
persistent=None,
215+
persistent=True,
216216
)
217217
)
218218
buffers.append(constant_tensor_fqn)
219-
ep.constants[constant_tensor_fqn] = constant_tensor
219+
ep.state_dict[constant_tensor_fqn] = constant_tensor
220220

221221
new_input_specs = []
222222
for s in graph_signature.input_specs:

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
11311131

11321132
# Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
11331133
FileCheck().check_not("_lifted_tensor_constant").check(
1134-
"_prop_tensor_constant0"
1134+
"_prop_tensor_constant1"
11351135
).check_not("executorch_exir_dialects_edge__ops_aten__to_copy_default").run(
11361136
new_ep.graph_module.code
11371137
)

0 commit comments

Comments
 (0)