[MPS] Fix partitioner parameters recognition #2777
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/2777
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit e5fe065 with merge base d4b3e5c:
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
    # Parameters are supported if any of their users are supported
    if is_parameter(self.edge_program, node):
```
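For context, here is a minimal, self-contained sketch of the pattern under discussion, where a constant node defers to its consumers. The class name, the `supported_targets` set, the recursion over `node.users`, and the simplified `is_parameter` below are illustrative assumptions, not the PR's actual code:

```python
import torch
from torch.fx.passes.operator_support import OperatorSupportBase

# Assumed available in recent PyTorch versions; helpers for lifted graphs.
from torch._export.utils import is_buffer, is_param


def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    # Simplified: treat lifted params and buffers as "parameters" (constants).
    return is_param(exp_prog, node) or is_buffer(exp_prog, node)


class ExampleOperatorSupport(OperatorSupportBase):
    """Illustrative operator-support check; not the MPS backend's real class."""

    def __init__(self, edge_program: torch.export.ExportedProgram, supported_targets):
        self.edge_program = edge_program
        self.supported_targets = supported_targets  # e.g. {"aten.add.Tensor"}

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Constants carry no compute of their own, so they follow their
        # consumers: a parameter node is supported if any of its users is.
        if is_parameter(self.edge_program, node):
            return any(self.is_node_supported(submodules, u) for u in node.users)
        return node.op == "call_function" and str(node.target) in self.supported_targets
```

A capability-based partitioner could then consume a check like this so that constants end up in the same partition as the operators that use them.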
Actually, after reading the PR again, why do the following lines mean it's a parameter?
Here, "it's a parameter" means we're only checking whether it's a constant (lifted buffer).
> it's a constant (lifted buffer)

Right, I'm trying to find the logic that determines whether it's a constant. This line doesn't look like it does that.
We're checking whether it's a constant: it is one if `is_param` or `is_buffer` is true (or if `is_get_attr` is true, which I think is no longer the case after moving to lifted graphs).
```python
def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if a node is a lifted parameter (static data such as weights and biases
    that are supplied as inputs to the graph).

    Args:
        exp_prog (torch.export.ExportedProgram): the exported program to check against
        node (torch.fx.Node): the node to check

    Returns:
        bool: True if the node is a lifted parameter, buffer, or get_attr node
    """
    return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node)
```
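To see why `is_param`/`is_buffer` are the relevant checks on a lifted graph, here is a small sketch (the module and shapes are arbitrary, chosen only for illustration) showing that weights and biases appear as placeholder inputs rather than `get_attr` nodes after `torch.export`:

```python
import torch
from torch._export.utils import is_buffer, is_param


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(M(), (torch.randn(2, 4),))

# In the lifted graph, the weight and bias show up as placeholders recorded
# in the graph signature, which is what is_param/is_buffer consult.
for node in ep.graph.nodes:
    if node.op == "placeholder":
        print(node.name, "param:", is_param(ep, node), "buffer:", is_buffer(ep, node))
```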
@pytorchbot cherry-pick --onto release/0.2 -c critical
Cherry picking #2777
Command
Details for Dev Infra team
Raised by workflow job
Fixes parameter recognition for the MPS backend.
Testing:
cc @cccclai, @shoumikhin