Skip to content

Commit 6c3e900

Browse files
cccclaifacebook-github-bot
authored andcommitted
fix coreml partitioner for llama (#2816)
Summary: Fix the constant nodes tagging for coreml partitioner ``` python3 -m examples.models.llama2.export_llama --coreml --use_kv_cache ``` Differential Revision: D55663281
1 parent 4f3c4e6 commit 6c3e900

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torch.export.exported_program import ExportedProgram
2222
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2323
from torch.fx.passes.operator_support import OperatorSupportBase
24+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2425

2526
logger = logging.getLogger(__name__)
2627
logger.setLevel(logging.WARNING)
@@ -87,8 +88,20 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
8788
tag = f"tag{partition.id}"
8889
node.meta["delegation_tag"] = tag
8990
partition_tags[tag] = self.delegation_spec
90-
91-
tag_constant_data(exported_program)
91+
node_args = node.args
92+
for arg in node_args:
93+
# tag the data nodes if they're used by coreml backend
94+
if isinstance(arg, torch.fx.Node):
95+
is_attr = (
96+
arg.op == "placeholder"
97+
and (
98+
is_param(exported_program, arg)
99+
or is_buffer(exported_program, arg)
100+
or is_lifted_tensor_constant(exported_program, arg)
101+
)
102+
) or (arg.op == "get_attr")
103+
if is_attr:
104+
arg.meta["delegation_tag"] = tag
92105

93106
return PartitionResult(
94107
tagged_exported_program=exported_program, partition_tags=partition_tags

0 commit comments

Comments
 (0)