Skip to content

Commit 73192aa

Browse files
cccclaifacebook-github-bot
authored andcommitted
fix coreml partitioner for llama (#2816)
Summary: Fix the constant nodes tagging for coreml partitioner ``` python3 -m examples.models.llama2.export_llama --coreml --use_kv_cache ``` Differential Revision: D55663281
1 parent 41290b4 commit 73192aa

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Partitioner,
1818
PartitionResult,
1919
)
20-
from executorch.exir.backend.utils import tag_constant_data
20+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2121
from torch.export.exported_program import ExportedProgram
2222
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2323
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -87,8 +87,19 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
8787
tag = f"tag{partition.id}"
8888
node.meta["delegation_tag"] = tag
8989
partition_tags[tag] = self.delegation_spec
90-
91-
tag_constant_data(exported_program)
90+
node_args = node.args
91+
for arg in node_args:
92+
# tag the data nodes if they're used by coreml backend
93+
if isinstance(arg, torch.fx.Node):
94+
if (
95+
arg.op == "placeholder"
96+
and (
97+
is_param(exported_program, arg)
98+
or is_buffer(exported_program, arg)
99+
or is_lifted_tensor_constant(exported_program, arg)
100+
)
101+
):
102+
arg.meta["delegation_tag"] = tag
92103

93104
return PartitionResult(
94105
tagged_exported_program=exported_program, partition_tags=partition_tags

0 commit comments

Comments
 (0)