Commit 6ee7b48

angelayi authored and facebook-github-bot committed
Remove metadata on delegate getitem nodes (#342)
Summary: mcr229 filed an issue: the delegate's getitem nodes (the nodes pointing to each result of the call_delegate call) retain the metadata of the original nodes, specifically the source_fn metadata. This causes a problem when there are two calls to to_backend. The first call partitions torch.nn.Linear using the source_fn metadata and creates a call_delegate node along with getitem nodes, which now carry the torch.nn.Linear source_fn metadata. When the second to_backend call also wants to partition based on the torch.nn.Linear source_fn metadata, it will incorrectly partition the getitem nodes belonging to the delegates created by the first call.

Implementation-wise, this happens because fuse_as_graphmodule automatically propagates the metadata of the partitioned nodes to the getitem nodes, so we need to insert an extra pass that removes the metadata on these nodes. Note that this also removes the "val" metadata, but we bring it back in the final ExportPass() call at the end of to_backend.

Reviewed By: digantdesai, cccclai

Differential Revision: D49264387
1 parent 7f395fd commit 6ee7b48
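For context, the fix boils down to popping the propagated keys off the getitem users of the fused call_module node. A minimal sketch of that idea follows; the helper name strip_partition_getitem_meta is invented here for illustration, and the real change lives in create_submodule_from_nodes in exir/lowered_backend_module.py, shown in the diff below.

import operator

import torch.fx


def strip_partition_getitem_meta(gm: torch.fx.GraphModule) -> None:
    # Illustrative helper, not the actual API: for every fused call_module node,
    # drop the metadata that was propagated onto its getitem users so that a
    # later to_backend call cannot match them by source_fn again.
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        for user in node.users:
            if user.op == "call_function" and user.target is operator.getitem:
                user.meta.pop("source_fn", None)
                user.meta.pop("nn_module_stack", None)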

File tree

2 files changed: +54 -12 lines changed

exir/backend/test/test_backends.py

Lines changed: 27 additions & 5 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import operator
 import unittest
 from typing import Dict, List
 
@@ -816,6 +817,16 @@ def forward(self, a, x, b):
         executorch_prog.exported_program = to_backend(
             ep.exported_program, AddMulPartitionerDemo
         )
+
+        for node in executorch_prog.exported_program.graph.nodes:
+            if node.op == "call_function" and node.target is executorch_call_delegate:
+                for user in node.users:
+                    self.assertTrue(
+                        user.op == "call_function" and user.target == operator.getitem
+                    )
+                    self.assertTrue(user.meta.get("source_fn", None) is None)
+                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)
+
         executorch_prog = executorch_prog.to_executorch(
             config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
         )
@@ -864,7 +875,7 @@ def __init__(self):
 
             def forward(self, x, y):
                 x = self.add_one(x) * y
-                return self.add_one(x)
+                return self.add_one(x), self.add_one(y)
 
         inputs = (torch.randn(1, 3), torch.randn(1, 3))
         orig_res = Model()(*inputs)
@@ -873,21 +884,32 @@ def forward(self, x, y):
         executorch_prog.exported_program = to_backend(
             ep.exported_program, AddAttributePartitionerDemo
         )
+
+        for node in executorch_prog.exported_program.graph.nodes:
+            if node.op == "call_function" and node.target is executorch_call_delegate:
+                for user in node.users:
+                    self.assertTrue(
+                        user.op == "call_function" and user.target == operator.getitem
+                    )
+                    self.assertTrue(user.meta.get("source_fn", None) is None)
+                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)
+
         executorch_prog = executorch_prog.to_executorch(
             config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
         )
 
         # Check the delegated submodules
         lowered_submodules = get_lowered_submodules(executorch_prog.dump_graph_module())
         self.assertEqual(len(lowered_submodules), 2)
-        for _, lowered_submodule, _ in lowered_submodules:
-            # Attributes should be stored in the lowered module
-            self.check_delegate_input(lowered_submodule, 1)
+        # Attributes should be stored in the lowered module
+        self.check_delegate_input(lowered_submodules[0][1], 1)
+        self.check_delegate_input(lowered_submodules[1][1], 2)
 
         executorch_prog.buffer
 
         new_res = executorch_prog.dump_graph_module()(*inputs)
-        self.assertTrue(torch.allclose(orig_res, new_res[0]))
+        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
+        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))
 
     def test_bad_partitioner(self):
         """

exir/lowered_backend_module.py

Lines changed: 27 additions & 7 deletions
@@ -459,17 +459,37 @@ def create_submodule_from_nodes(
     _fixup_output_node(sub_gm)
 
     gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
+    submodule_node = None
+    for node in gm.graph.nodes:
+        if node.op == "call_module":
+            if node.target == submodule_name:
+                submodule_node = node
+            else:
+                raise RuntimeError(
+                    f"The submodule created with nodes {node_list} did not form \
+                    one fully contained subgraph. Check that these nodes form a \
+                    fully contained graph. Partitioned graph: {gm.graph}."
+                )
+
     if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
         # If the original output is a single tensor, it has been
         # pytree.tree_flatten-ed to be a singleton list, so we want to replace
         # all uses with a getitem call to the 0th index of the result
-        for node in gm.graph.nodes:
-            if node.op == "call_module":
-                with gm.graph.inserting_after(node):
-                    proxy_out = torch.fx.Proxy(node)[0].node  # type: ignore[index]
-                    node.replace_all_uses_with(proxy_out, propagate_meta=True)
-                    # Reset the args since it was overwritten in the previous line
-                    proxy_out.args = (node, 0)
+        with gm.graph.inserting_after(submodule_node):
+            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
+            submodule_node.replace_all_uses_with(proxy_out)
+            proxy_out.meta["val"] = submodule_node.meta["val"]
+            # Reset the args since it was overwritten in the previous line
+            proxy_out.args = (submodule_node, 0)
+    else:
+        # fuse_as_graphmodule will automatically propagate the metadata of the
+        # partition's last node to the getitem nodes that appear after the
+        # call_module node. However, in the case of delegation we do not want
+        # these getitem nodes to contain irrelevant previous metadata
+        # (ex. source_fn, nn_module_stack)
+        for user_node in submodule_node.users:
+            user_node.meta.pop("nn_module_stack", None)
+            user_node.meta.pop("source_fn", None)
 
     erase_nodes(gm, sorted_nodes)
 
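To see why the metadata ends up on the getitem nodes in the first place, here is a small, self-contained torch.fx demo of the propagate_meta=True behavior used by the replaced single-output branch; per the summary above, fuse_as_graphmodule performs the same kind of propagation onto getitem nodes for multi-output partitions. The AddOne module and the fake source_fn value below are made up for illustration and are not part of this change.

import operator

import torch
import torch.fx


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(AddOne())

for node in list(gm.graph.nodes):
    if node.op == "call_function":
        # Pretend this node carries partitioner-relevant metadata.
        node.meta["source_fn"] = ("add", operator.add)
        with gm.graph.inserting_after(node):
            # Indexing a Proxy appends an operator.getitem node to the graph.
            getitem = torch.fx.Proxy(node)[0].node
        # propagate_meta=True copies node.meta onto the replacement node...
        node.replace_all_uses_with(getitem, propagate_meta=True)
        # ...and also rewrites getitem's own args, so restore them.
        getitem.args = (node, 0)
        # The stale metadata is now visible on the getitem node:
        print(getitem.meta.get("source_fn"))  # -> ("add", operator.add)

Clearing source_fn and nn_module_stack on those getitem users, as the new else branch above does, is what prevents a second to_backend call from re-matching them.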
