Skip to content

Remove metadata on delegate getitem nodes #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions exir/backend/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Dict, List

Expand Down Expand Up @@ -816,6 +817,16 @@ def forward(self, a, x, b):
executorch_prog.exported_program = to_backend(
ep.exported_program, AddMulPartitionerDemo
)

for node in executorch_prog.exported_program.graph.nodes:
if node.op == "call_function" and node.target is executorch_call_delegate:
for user in node.users:
self.assertTrue(
user.op == "call_function" and user.target == operator.getitem
)
self.assertTrue(user.meta.get("source_fn", None) is None)
self.assertTrue(user.meta.get("nn_module_stack", None) is None)

executorch_prog = executorch_prog.to_executorch(
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
)
Expand Down Expand Up @@ -864,7 +875,7 @@ def __init__(self):

def forward(self, x, y):
x = self.add_one(x) * y
return self.add_one(x)
return self.add_one(x), self.add_one(y)

inputs = (torch.randn(1, 3), torch.randn(1, 3))
orig_res = Model()(*inputs)
Expand All @@ -873,21 +884,32 @@ def forward(self, x, y):
executorch_prog.exported_program = to_backend(
ep.exported_program, AddAttributePartitionerDemo
)

for node in executorch_prog.exported_program.graph.nodes:
if node.op == "call_function" and node.target is executorch_call_delegate:
for user in node.users:
self.assertTrue(
user.op == "call_function" and user.target == operator.getitem
)
self.assertTrue(user.meta.get("source_fn", None) is None)
self.assertTrue(user.meta.get("nn_module_stack", None) is None)

executorch_prog = executorch_prog.to_executorch(
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
)

# Check the delegated submodules
lowered_submodules = get_lowered_submodules(executorch_prog.dump_graph_module())
self.assertEqual(len(lowered_submodules), 2)
for _, lowered_submodule, _ in lowered_submodules:
# Attributes should be stored in the lowered module
self.check_delegate_input(lowered_submodule, 1)
# Attributes should be stored in the lowered module
self.check_delegate_input(lowered_submodules[0][1], 1)
self.check_delegate_input(lowered_submodules[1][1], 2)

executorch_prog.buffer

new_res = executorch_prog.dump_graph_module()(*inputs)
self.assertTrue(torch.allclose(orig_res, new_res[0]))
self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

def test_bad_partitioner(self):
"""
Expand Down
34 changes: 27 additions & 7 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,17 +459,37 @@ def create_submodule_from_nodes(
_fixup_output_node(sub_gm)

gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
submodule_node = None
for node in gm.graph.nodes:
if node.op == "call_module":
if node.target == submodule_name:
submodule_node = node
else:
raise RuntimeError(
f"The submodule created with nodes {node_list} did not form \
one fully contained subgraph. Check that these nodes form a \
fully contained graph. Partitioned graph: {gm.graph}."
)

if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
# If the original output is a single tensor, it has been
# pytree.tree_flatten-ed to be a singleton list, so we want to replace
# all uses with a getitem call to the 0th index of the result
for node in gm.graph.nodes:
if node.op == "call_module":
with gm.graph.inserting_after(node):
proxy_out = torch.fx.Proxy(node)[0].node # type: ignore[index]
node.replace_all_uses_with(proxy_out, propagate_meta=True)
# Reset the args since it was overwritten in the previous line
proxy_out.args = (node, 0)
with gm.graph.inserting_after(submodule_node):
proxy_out = torch.fx.Proxy(submodule_node)[0].node # type: ignore[index]
submodule_node.replace_all_uses_with(proxy_out)
proxy_out.meta["val"] = submodule_node.meta["val"]
# Reset the args since it was overwritten in the previous line
proxy_out.args = (submodule_node, 0)
else:
# fuse_as_graphmodule will automatically propagate the metadata of the
# partition's last node to the getitem nodes that appear after the
# call_module node. However, in the case of delegation we do not want
# these getitem nodes to contain irrelevant previous metadata
# (ex. source_fn, nn_module_stack)
for user_node in submodule_node.users:
user_node.meta.pop("nn_module_stack", None)
user_node.meta.pop("source_fn", None)

erase_nodes(gm, sorted_nodes)

Expand Down