Skip to content

Commit 498f249

Browse files
authored
Fix output spec + insert clone for constant_prop_pass.
Differential Revision: D75473310 Pull Request resolved: #11209
1 parent b5567be commit 498f249

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

exir/passes/constant_prop_pass.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,37 @@ def create_constant_nodes_and_return_specs(
295295
return name_to_spec_dict
296296

297297

298+
def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
299+
"""
300+
Update the output node and output specs in the exported program.
301+
In case a constant node is used as output, we replace it with a clone of the constant node.
302+
"""
303+
# Dict [node.name -> InputSpec]
304+
updated_constant_placeholders = get_constant_placeholder_dict(exported_program)
305+
output = exported_program.graph.find_nodes(op="output")[0]
306+
output_nodes = cast(list[torch.fx.Node], list(output.args[0]))
307+
output_specs = exported_program.graph_signature.output_specs
308+
assert len(output_nodes) == len(output_specs)
309+
310+
for i in range(len(output_specs)):
311+
out_node = output_nodes[i]
312+
if out_node not in updated_constant_placeholders:
313+
continue
314+
315+
with exported_program.graph.inserting_after(out_node):
316+
new_node = exported_program.graph.call_function(
317+
exir_ops.edge.aten.clone.default, (out_node,)
318+
)
319+
assert "val" in out_node.meta
320+
new_node.meta["val"] = out_node.meta["val"]
321+
output_nodes[i] = new_node
322+
323+
# Update the constant-propagated output node.
324+
output_specs[i].arg = TensorArgument(name=output_nodes[i].name)
325+
326+
output.args = (output_nodes,)
327+
328+
298329
def constant_prop_pass(
299330
exported_program: ExportedProgram,
300331
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
@@ -341,12 +372,12 @@ def constant_prop_pass(
341372

342373
# Generate new input spec.
343374
new_input_specs = []
344-
for node in exported_program.graph.nodes:
345-
if node.op != "placeholder":
346-
continue
375+
for node in exported_program.graph.find_nodes(op="placeholder"):
347376
new_input_specs.append(name_to_spec_dict[node.name])
348377
exported_program.graph_signature.input_specs = new_input_specs
349378

379+
_update_output_node_and_specs(exported_program)
380+
350381
# Cleanup the graph.
351382
exported_program.graph.eliminate_dead_code()
352383
exported_program.graph_module.recompile()

exir/tests/test_passes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,34 @@ def forward(self, x):
10261026
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
10271027
).run(gm.code)
10281028

1029+
def test_constant_prop_for_output(self) -> None:
1030+
class Add(torch.nn.Module):
1031+
def forward(self) -> torch.Tensor:
1032+
return torch.add(torch.tensor(3), torch.tensor(5))
1033+
1034+
add = Add()
1035+
1036+
edge = to_edge(
1037+
export(add, (), strict=True),
1038+
compile_config=EdgeCompileConfig(_skip_dim_order=False),
1039+
)
1040+
# Check there is a lifted tensor followed by a to_copy node
1041+
FileCheck().check("c_lifted_tensor_0").check("c_lifted_tensor_1").run(
1042+
edge.exported_program().graph_module.code
1043+
)
1044+
1045+
edge._edge_programs["forward"] = constant_prop_pass(
1046+
edge.exported_program("forward")
1047+
)
1048+
1049+
# Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.
1050+
FileCheck().check_not("c_lifted_tensor_").check("_prop_tensor_constant").run(
1051+
edge.exported_program().graph_module.code
1052+
)
1053+
# Validate that the program successfully passes validation to executorch:
1054+
edge.exported_program()._validate()
1055+
edge.to_executorch()
1056+
10291057
def test_constant_prop_pass_for_add(self) -> None:
10301058
class Add(torch.nn.Module):
10311059
def forward(self, x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)