Skip to content

Commit f471556

Browse files
YifanShenSZyifan_shen3
andauthored
Partition Mutable Buffer as Core ML State (#5165)
* partition mutable buffer to coreml state * delegate llama mutable buffer to coreml * fix lint * support embedding quantize * try fix CI: 1. pin coremltools 8.0b2; 2. refrain from defaulting stateful llama until CI machine upgraded to MacOS 15 * address review comments: 1. add arg help info; 2. add mutable buffer partition log * fix CI: executorch example model test env is using older transformers, that does not support numpy 2.0 --------- Co-authored-by: yifan_shen3 <[email protected]>
1 parent c5a385e commit f471556

File tree

6 files changed

+124
-18
lines changed

6 files changed

+124
-18
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Partitioner,
1818
PartitionResult,
1919
)
20-
from executorch.exir.backend.utils import tag_constant_data
20+
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
2121
from torch.export.exported_program import ExportedProgram
2222
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2323
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -61,6 +61,7 @@ def __init__(
6161
self,
6262
skip_ops_for_coreml_delegation: Optional[List[str]] = None,
6363
compile_specs: Optional[List[CompileSpec]] = None,
64+
take_over_mutable_buffer: Optional[bool] = True,
6465
) -> None:
6566
if skip_ops_for_coreml_delegation is None:
6667
skip_ops_for_coreml_delegation = []
@@ -69,6 +70,7 @@ def __init__(
6970
backend_id=CoreMLBackend.__name__,
7071
compile_specs=compile_specs if compile_specs is not None else [],
7172
)
73+
self.take_over_mutable_buffer = take_over_mutable_buffer
7274

7375
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7476
# Run the CapabilityBasedPartitioner to return the largest possible
@@ -89,6 +91,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
8991
partition_tags[tag] = self.delegation_spec
9092

9193
tag_constant_data(exported_program)
94+
if self.take_over_mutable_buffer:
95+
logger.info(
96+
"Core ML partitioner will take over torch mutable buffer as Core ML state, "
97+
"so if your model contains mutable buffer, "
98+
"then you will need MacOS15+/iOS18+ to execute. "
99+
"If you want your mutable buffer model to be compatible with older OS, "
100+
"then please set `take_over_mutable_buffer=False`"
101+
)
102+
tag_mutated_buffer(exported_program)
92103

93104
return PartitionResult(
94105
tagged_exported_program=exported_program, partition_tags=partition_tags

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ rm -rf "$COREML_DIR_PATH/third-party"
2424
mkdir "$COREML_DIR_PATH/third-party"
2525

2626
echo "${green}ExecuTorch: Cloning coremltools."
27-
git clone --depth 1 --branch 8.0b1 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
27+
git clone --depth 1 --branch 8.0b2 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
2828
cd $COREMLTOOLS_DIR_PATH
2929

3030
STATUS=$?
@@ -47,6 +47,11 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel
4747

4848
echo "${green}ExecuTorch: Installing coremltools."
4949
pip install "$COREMLTOOLS_DIR_PATH"
50+
# CoreMLTools have started supporting numpy 2.0,
51+
# but ExecuTorch example model test env is still using older transformers,
52+
# so for now we will need to downgrade numpy to 1.x
53+
# TODO: Remove this numpy downgrade once later transformers starts to be used
54+
pip install numpy==1.26.4
5055
STATUS=$?
5156
if [ $STATUS -ne 0 ]; then
5257
echo "${red}ExecuTorch: Failed to install coremltools."

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44

55
import unittest
66

7+
import coremltools as ct
8+
79
import executorch.exir
810

911
import torch
1012
import torchvision
1113

14+
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1215
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
1316

1417

@@ -86,8 +89,54 @@ def test_vit_skip_conv(self):
8689
if node.op == "call_function"
8790
] == total
8891

92+
def test_buffer(self):
93+
embedding_dim = 3
94+
max_seq_len = 2
95+
96+
class Model(torch.nn.Module):
97+
def __init__(self):
98+
super().__init__()
99+
self.register_buffer(
100+
"cache",
101+
torch.zeros((max_seq_len, embedding_dim), dtype=torch.float32),
102+
)
103+
104+
def forward(self, q, k_val, input_pos):
105+
q_T = q.transpose(0, 1)
106+
k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
107+
attn = k.mm(q_T)
108+
return attn
109+
110+
model = Model()
111+
model.eval()
112+
113+
q = torch.randn((1, embedding_dim))
114+
k_val = torch.randn((1, embedding_dim))
115+
input_pos = torch.tensor([0])
116+
example_inputs = (q, k_val, input_pos)
117+
exir_program_aten = torch.export.export(model, example_inputs)
118+
119+
compile_specs = CoreMLBackend.generate_compile_specs(
120+
minimum_deployment_target=ct.target.iOS18
121+
)
122+
partitioner = CoreMLPartitioner(compile_specs=compile_specs)
123+
edge_program_manager = executorch.exir.to_edge(
124+
exir_program_aten, compile_config=self.edge_compile_config
125+
)
126+
delegated_program_manager = edge_program_manager.to_backend(partitioner)
127+
128+
assert [
129+
node.target.__name__
130+
for node in delegated_program_manager.exported_program().graph.nodes
131+
if node.op == "call_function"
132+
] == [
133+
"executorch_call_delegate",
134+
"getitem",
135+
]
136+
89137

90138
if __name__ == "__main__":
91139
test_runner = TestCoreMLPartitioner()
92140
test_runner.test_add_sub_skip_mm()
93141
test_runner.test_vit_skip_conv()
142+
test_runner.test_buffer()

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,11 @@ def build_args_parser() -> argparse.ArgumentParser:
288288
parser.add_argument("-V", "--vulkan", action="store_true")
289289
parser.add_argument("--mps", action="store_true")
290290
parser.add_argument("--coreml", action="store_true")
291+
parser.add_argument(
292+
"--coreml-enable-state",
293+
action="store_true",
294+
help="This option is only for coreml, and is only supported for MacOS15+/iOS18+",
295+
)
291296
parser.add_argument(
292297
"--qnn",
293298
action="store_true",
@@ -523,7 +528,9 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
523528

524529
if args.coreml:
525530
coreml_partitioner = get_coreml_partitioner(
526-
args.use_kv_cache, args.pt2e_quantize
531+
args.use_kv_cache and args.coreml_enable_state,
532+
args.embedding_quantize,
533+
args.pt2e_quantize,
527534
)
528535
partitioners.append(coreml_partitioner)
529536
modelname = f"coreml_{modelname}"

exir/backend/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,40 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
383383
node.meta["delegation_tag"] = user_tags.pop()
384384

385385

386+
def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
387+
"""
388+
Util function for partitioners. This function tags the mutated buffer nodes
389+
whose users all belong within the same partition. This should be called after tagging all other nodes.
390+
Any buffer which is used as input to a subgraph, will be tagged with the same tag as that
391+
subgraph. Throw error when buffers is used across different partitions. That is the
392+
underlying data will be owned by multiple delegates.
393+
"""
394+
for node in edge_program.graph.nodes:
395+
# Determine whether this node is a mutated buffer
396+
is_mutated_buffer_node = False
397+
if node.op == "placeholder" and is_buffer(edge_program, node):
398+
for node_user in node.users:
399+
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
400+
is_mutated_buffer_node = True
401+
break
402+
# This node is mutated buffer, tag it
403+
if is_mutated_buffer_node:
404+
user_tags = set()
405+
for user in node.users:
406+
user_tag = user.meta.get("delegation_tag", None)
407+
if user_tag is not None:
408+
user_tags.add(user_tag)
409+
if len(user_tags) > 1:
410+
logging.info(
411+
f"The data node is used across multiple partitions, including {user_tags}. "
412+
"If the data is too large and it's not preferred to copy, please tag the "
413+
"constant node like node.['no_copy'] = True and they won't be copied."
414+
)
415+
# tag the data node with the same tag as the last user
416+
if len(user_tags) > 0:
417+
node.meta["delegation_tag"] = user_tags.pop()
418+
419+
386420
# TODO - style: use templated types
387421
class DelegateMappingBuilder:
388422
"""

extension/llm/export/partitioner_lib.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,10 @@ def get_mps_partitioner(use_kv_cache: bool = False):
5656

5757

5858
def get_coreml_partitioner(
59-
use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None
59+
enable_state: bool = False,
60+
embedding_quantize: Optional[str] = None,
61+
pt2e_quantize: Optional[str] = None,
6062
):
61-
assert (
62-
use_kv_cache is True
63-
), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
6463
try:
6564
import coremltools as ct
6665
from executorch.backends.apple.coreml.compiler import ( # pyre-ignore
@@ -75,22 +74,22 @@ def get_coreml_partitioner(
7574
)
7675

7776
minimum_deployment_target = ct.target.iOS15
78-
# In Core ML, quantization in introduced in iOS 16
79-
if pt2e_quantize is not None:
77+
# In Core ML, stateful execution is introduced in iOS 18
78+
if enable_state:
79+
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
80+
# In Core ML, quantization is introduced in iOS 16
81+
if embedding_quantize is not None or pt2e_quantize is not None:
8082
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
8183
# In Core ML, 8-bit activation quantization is introduced in iOS 17
82-
if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
84+
if (
85+
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
86+
) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
8387
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
8488
# In Core ML, 4-bit weight compression is introduced in iOS 18
85-
if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
89+
if (
90+
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4
91+
) or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
8692
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
87-
# In Core ML, stateful execution is introduced in iOS 18
88-
# TODO (https://github.com/pytorch/executorch/issues/4209)
89-
# For now, since mutable buffer is kept in executorch runtime,
90-
# state is out of place and can be handled by older iOS.
91-
# Once mutable buffer can be handed over to delegate, i.e. state becomes in-place, we will have
92-
# if use_kv_cache:
93-
# minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
9493

9594
compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
9695
minimum_deployment_target=minimum_deployment_target,
@@ -101,6 +100,7 @@ def get_coreml_partitioner(
101100
)
102101
return CoreMLPartitioner( # pyre-fixme[16]
103102
compile_specs=compile_specs,
103+
take_over_mutable_buffer=enable_state,
104104
)
105105

106106

0 commit comments

Comments
 (0)