Remove metadata on delegate getitem nodes #342
Closed
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
mcr229 filed an issue which is that the delegate's getitem nodes (the
nodes pointing to each result of the call_delegate call) contain the metadata
of the original nodes, specifically the source_fn metadata. This causes an
issue because if we have 2 calls to to_backend, the first call will partition
torch.nn.Linear using source_fn metadata, and create a call_delegate node along
with getitem calls which now contain the torch.nn.Linear source_fn metadata.
When a second to_backend call comes along, if it also wants to partition based
on torch.nn.Linear source_fn metadata, it will incorrectly partition the
getitem nodes to the delegates made by the first to_backend call.
Implementation wise, this happens because the fuse_as_graphmodule function will
automatically propagate metadata of the nodes being partitioned, to the getitem
nodes. So, we will need to insert an extra pass to remove the metadata on these
nodes. Note that this will also remove the "val" metadata, but we will bring it
back in final the ExportPass() call at the end of to_backend.
Differential Revision: D49264387