Skip to content

Remove metadata on delegate getitem nodes #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

angelayi
Copy link
Contributor

Summary:
mcr229 filed an issue which is that the delegate's getitem nodes (the
nodes pointing to each result of the call_delegate call) contain the metadata
of the original nodes, specifically the source_fn metadata. This causes an
issue because if we have 2 calls to to_backend, the first call will partition
torch.nn.Linear using source_fn metadata, and create a call_delegate node along
with getitem calls which now contain the torch.nn.Linear source_fn metadata.
When a second to_backend call comes along, if it also wants to partition based
on torch.nn.Linear source_fn metadata, it will incorrectly partition the
getitem nodes to the delegates made by the first to_backend call.

Implementation wise, this happens because the fuse_as_graphmodule function will
automatically propagate metadata of the nodes being partitioned, to the getitem
nodes. So, we will need to insert an extra pass to remove the metadata on these
nodes. Note that this will also remove the "val" metadata, but we will bring it
back in final the ExportPass() call at the end of to_backend.

Differential Revision: D49264387

@facebook-github-bot facebook-github-bot added CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported labels Sep 14, 2023
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D49264387

Summary:

mcr229 filed an issue which is that the delegate's getitem nodes (the
nodes pointing to each result of the call_delegate call) contain the metadata
of the original nodes, specifically the source_fn metadata. This causes an
issue because if we have 2 calls to to_backend, the first call will partition
torch.nn.Linear using source_fn metadata, and create a call_delegate node along
with getitem calls which now contain the torch.nn.Linear source_fn metadata.
When a second to_backend call comes along, if it also wants to partition based
on torch.nn.Linear source_fn metadata, it will incorrectly partition the
getitem nodes to the delegates made by the first to_backend call.

Implementation wise, this happens because the fuse_as_graphmodule function will
automatically propagate metadata of the nodes being partitioned, to the getitem
nodes. So, we will need to insert an extra pass to remove the metadata on these
nodes. Note that this will also remove the "val" metadata, but we will bring it
back in final the ExportPass() call at the end of to_backend.

Reviewed By: digantdesai, cccclai

Differential Revision: D49264387
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D49264387

@facebook-github-bot
Copy link
Contributor

This pull request has been merged in d301047.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants