[MPS] Fix partitioner parameters recognition #2777


Closed
8 changes: 8 additions & 0 deletions backends/apple/mps/partition/mps_partitioner.py
@@ -9,6 +9,7 @@
import torch
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_partitions_from_list_of_nodes,
@@ -30,8 +31,15 @@
class MPSOperatorSupport(OperatorSupportBase):
    def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
        self.node_visitors = get_node_visitors(edge_program)
        self.edge_program = edge_program

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Parameters are supported if any of their users are supported
        if is_parameter(self.edge_program, node):
cccclai (Contributor) commented on Mar 31, 2024:

Actually, after reading the PR again, why do the following lines mean it's a parameter?

Collaborator (Author) replied:

Here, by "it's a parameter" we are only checking whether it is a constant (lifted buffer).

Contributor replied:

> it's a constant (lifted buffer)

Right, I'm trying to find the logic that determines whether it's a constant. This line doesn't look like it.

Collaborator (Author) replied:

We're treating the node as a constant if is_param or is_buffer is true (or if is_get_attr is true, which I think is no longer the case after moving to lifted graphs):

from torch._export.utils import is_buffer, is_param


def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if a node is a lifted parameter (static data, such as weights and
    biases, that is supplied as an input to the graph).

    Args:
        exp_prog (torch.export.ExportedProgram): the exported program whose
            graph signature the node is looked up in.
        node (torch.fx.Node): the node to check.

    Returns:
        bool: True if the node is a get_attr node, a lifted parameter, or a
            lifted buffer.
    """
    # is_get_attr is a helper defined alongside this function that checks
    # whether node.op == "get_attr".
    return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node)

            return any(
                self.is_node_supported(submodules, user) for user in node.users.keys()
            )

        if node.op != "call_function":
            return False
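The rule this PR adds can be sketched with a toy graph: a parameter (constant placeholder) node is claimed by the partitioner exactly when at least one of its users is itself supported. This is a minimal, self-contained illustration of that recursion — the `Node` class and `SUPPORTED_OPS` set are invented for the sketch and are not the executorch API.

```python
# Toy model of the partitioner rule added in this PR. Illustrative
# names only; not the executorch API.

class Node:
    def __init__(self, name, op, users=None, is_param=False):
        self.name = name
        self.op = op
        self.users = users or []   # nodes that consume this node's output
        self.is_param = is_param   # lifted parameter/buffer placeholder

# Hypothetical set of ops the backend can lower.
SUPPORTED_OPS = {"aten.add", "aten.mm"}

def is_node_supported(node):
    if node.is_param:
        # Parameters follow their consumers: supported iff any user is.
        return any(is_node_supported(u) for u in node.users)
    return node.op in SUPPORTED_OPS

mm = Node("mm", "aten.mm")
weight = Node("weight", "placeholder", users=[mm], is_param=True)
sin = Node("sin", "aten.sin")  # not in SUPPORTED_OPS
bias = Node("bias", "placeholder", users=[sin], is_param=True)

print(is_node_supported(weight))  # weight feeds a supported mm
print(is_node_supported(bias))    # bias only feeds an unsupported sin
```

With this rule, a weight that feeds a lowerable op lands in the same partition as its consumer, while a constant whose only users stay on CPU is left out of the delegate.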
