Skip to content

Commit 4d4f37b

Browse files
cccclaifacebook-github-bot
authored andcommitted
fix coreml partitioner for llama (#2816)
Summary: Fix the constant nodes tagging for coreml partitioner ``` python3 -m examples.models.llama2.export_llama --coreml --use_kv_cache ``` Differential Revision: D55663281
1 parent 495cfd0 commit 4d4f37b

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torch.export.exported_program import ExportedProgram
2222
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2323
from torch.fx.passes.operator_support import OperatorSupportBase
24+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2425

2526
logger = logging.getLogger(__name__)
2627
logger.setLevel(logging.WARNING)
@@ -73,7 +74,6 @@ def __init__(
7374
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7475
# Run the CapabilityBasedPartitioner to return the largest possible
7576
# subgraphs containing the nodes with the tags
76-
logger.info("CoreMLPartitioner::partition")
7777
partition_tags = {}
7878

7979
capability_partitioner = CapabilityBasedPartitioner(
@@ -87,8 +87,20 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
8787
tag = f"tag{partition.id}"
8888
node.meta["delegation_tag"] = tag
8989
partition_tags[tag] = self.delegation_spec
90-
91-
tag_constant_data(exported_program)
90+
node_args = node.args
91+
for arg in node_args:
92+
# tag the data nodes if they're used by coreml backend
93+
if isinstance(arg, torch.fx.Node):
94+
is_attr = (
95+
arg.op == "placeholder"
96+
and (
97+
is_param(exported_program, arg)
98+
or is_buffer(exported_program, arg)
99+
or is_lifted_tensor_constant(exported_program, arg)
100+
)
101+
) or (arg.op == "get_attr")
102+
if is_attr:
103+
arg.meta["delegation_tag"] = tag
92104

93105
return PartitionResult(
94106
tagged_exported_program=exported_program, partition_tags=partition_tags

0 commit comments

Comments
 (0)