Skip to content

Commit e802715

Browse files
cccclaifacebook-github-bot
authored andcommitted
fix coreml partitioner for llama
Differential Revision: D55663281
1 parent 4b0ed91 commit e802715

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
PartitionResult,
1919
)
2020
from executorch.exir.backend.utils import tag_constant_data
21+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2122
from torch.export.exported_program import ExportedProgram
2223
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2324
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -88,7 +89,25 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
8889
node.meta["delegation_tag"] = tag
8990
partition_tags[tag] = self.delegation_spec
9091

91-
tag_constant_data(exported_program)
92+
is_attr = (
93+
node.op == "placeholder"
94+
and (
95+
is_param(exported_program, node)
96+
or is_buffer(exported_program, node)
97+
or is_lifted_tensor_constant(exported_program, node)
98+
)
99+
)
100+
# if all users of const/param/buffer nodes are partitioned then partition
101+
if is_attr:
102+
user_tags = set()
103+
for user in node.users:
104+
user_tag = user.meta.get("delegation_tag", None)
105+
if user_tag is not None:
106+
user_tags.add(user_tag)
107+
if len(user_tags) >= 1:
108+
# There are more than one user tag, just pick one and there is a pass later to copy the constant data
109+
node.meta["delegation_tag"] = user_tags.pop()
110+
92111

93112
return PartitionResult(
94113
tagged_exported_program=exported_program, partition_tags=partition_tags

0 commit comments

Comments
 (0)