Skip to content

Commit e5fe065

Browse files
committed
Fix partitioner parameters recognition (#2)
1 parent d4b3e5c commit e5fe065

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

backends/apple/mps/partition/mps_partitioner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
1111
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
12+
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
1213
from executorch.exir.backend.backend_details import CompileSpec
1314
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
1415
generate_partitions_from_list_of_nodes,
@@ -30,8 +31,15 @@
3031
class MPSOperatorSupport(OperatorSupportBase):
3132
def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
3233
self.node_visitors = get_node_visitors(edge_program)
34+
self.edge_program = edge_program
3335

3436
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
37+
# Parameters are supported if any of their users are supported
38+
if is_parameter(self.edge_program, node):
39+
return any(
40+
self.is_node_supported(submodules, user) for user in node.users.keys()
41+
)
42+
3543
if node.op != "call_function":
3644
return False
3745

0 commit comments

Comments
 (0)