
Commit 1793c4a

yifan_shen3 authored and facebook-github-bot committed
Preserve SDPA for CoreML (#5258)
Summary:

## Motivation

Starting from iOS 18, Core ML provides an SDPA op, so there is no longer a need to decompose the torch SDPA op.

## Solution

Following #3483, add `ops_not_to_decompose` to the CoreML partitioner, then use the `to_edge_transform_and_lower` API in the llama export.

Pull Request resolved: #5258

Reviewed By: kirklandsign

Differential Revision: D62550916

Pulled By: cccclai

fbshipit-source-id: bc238a04ea9eb38341157cd2388b7cd17a506bf1
1 parent 7998b7f commit 1793c4a
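
Note: a minimal sketch of the export flow this commit enables. to_edge_transform_and_lower and the Core ML partitioner are real ExecuTorch APIs, but the import path and constructor defaults below are assumptions and may differ between versions.

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # import path assumed
from executorch.exir import to_edge_transform_and_lower

class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 16, 64)
ep = torch.export.export(TinyAttention(), (q, k, v))

# Because the partitioner reports SDPA via ops_not_to_decompose, this API
# skips decomposing it, and the whole op can be delegated to Core ML.
edge = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])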

File tree

3 files changed: +152, -3 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 18 additions & 3 deletions
@@ -57,7 +57,9 @@
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,
+    replace_kv_cache_with_coreml_kv_cache,
     replace_kv_cache_with_simple_kv_cache,
+    replace_sdpa_with_coreml_sdpa,
     replace_sdpa_with_custom_op,
     replace_sdpa_with_flex_sdpa,
     replace_sdpa_with_simple_sdpa,
@@ -304,6 +306,11 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="This option is only for coreml, and is only supported for MacOS15+/iOS18+",
     )
+    parser.add_argument(
+        "--coreml-preserve-sdpa",
+        action="store_true",
+        help="This option is only for coreml: Preserve sdpa in torch edge program to use coreml iOS18.sdpa op",
+    )
     parser.add_argument(
         "--coreml-quantize",
         default=None,
@@ -527,6 +534,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     if args.coreml:
         coreml_partitioner = get_coreml_partitioner(
             args.use_kv_cache and args.coreml_enable_state,
+            args.coreml_preserve_sdpa,
             args.embedding_quantize,
             args.pt2e_quantize,
             args.coreml_quantize,
@@ -742,7 +750,7 @@ def _load_llama_model(
     )


-def _get_source_transforms(
+def _get_source_transforms(  # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
@@ -795,10 +803,17 @@ def _get_source_transforms(
         transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
         transforms.append(convert_linear_to_conv2d)

-    elif args.coreml or args.mps:
-        # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+    elif args.mps:
+        # Currently mps doesn't support sdpa op, use the simpler decomposition
         # to get free perf gain.
         transforms.append(replace_sdpa_with_simple_sdpa)
         transforms.append(replace_causal_mask)

+    elif args.coreml:
+        if args.coreml_preserve_sdpa:
+            transforms.append(replace_sdpa_with_coreml_sdpa)
+        else:
+            transforms.append(replace_sdpa_with_simple_sdpa)
+        transforms.append(replace_kv_cache_with_coreml_kv_cache)
+
     return transforms
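
With this change, a Core ML llama export can opt into the iOS 18 SDPA op from the command line. A hedged example invocation: --coreml-preserve-sdpa and --coreml-enable-state come from this diff, while the module path, checkpoint/params flags, and file paths are assumptions.

python -m examples.models.llama2.export_llama \
    --checkpoint llama2.pt --params params.json \
    --use_kv_cache --coreml --coreml-enable-state --coreml-preserve-sdpa

Without --coreml-preserve-sdpa, the Core ML branch falls back to replace_sdpa_with_simple_sdpa; replace_kv_cache_with_coreml_kv_cache is applied in both cases.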

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 130 additions & 0 deletions
@@ -195,6 +195,136 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
     return module


+@torch.library.custom_op("coreml::sdpa", mutates_args=())
+def sdpa(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
+    return torch.ops.aten.scaled_dot_product_attention.default(
+        q, k, v, attn_mask=attn_mask
+    )
+
+
+@torch.library.register_fake("coreml::sdpa")
+def _(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
+    expected_shape = list(q.shape)
+    expected_shape[-1] = v.shape[-1]
+    return q.new_empty(expected_shape)
+
+
+class SDPACoreML(torch.nn.Module):
+    """Similar to SDPASimple, but with coreml custom op to do SDPA calculation."""
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        attn_mask = mask[None, None, input_pos]
+
+        if self.n_rep > 1:
+            k = k.repeat_interleave(self.n_rep, dim=1)
+            v = v.repeat_interleave(self.n_rep, dim=1)
+
+        y = torch.ops.coreml.sdpa(q, k, v, attn_mask)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
+def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_coreml_sdpa(child)
+    return module
+
+
+class KVCacheCoreML(torch.nn.Module):
+    """
+    Rather than k_out[:, :, input_pos] = k_val, use torch.ops.aten.index_put_,
+    which can directly translate to CoreML iOS18.slice_update
+    """
+
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+        self.max_seq_length = max_seq_length
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
+        return k_out, v_out
+
+
+def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            setattr(
+                module,
+                name,
+                KVCacheCoreML(
+                    child.max_batch_size,
+                    child.max_seq_length,
+                    child.n_heads,
+                    child.head_dim,
+                    child.k_cache.dtype,
+                ),
+            )
+        else:
+            replace_kv_cache_with_coreml_kv_cache(child)
+    return module
+
+
 class KVCacheSimple(torch.nn.Module):
     def __init__(
         self,
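
The additions above follow the standard PyTorch custom-op pattern: an eager kernel registered with torch.library.custom_op plus a shape-only fake kernel registered with torch.library.register_fake, so torch.export records one opaque node instead of decomposing attention. A self-contained sketch under a hypothetical demo:: namespace (coreml::sdpa above is the real op):

import torch

@torch.library.custom_op("demo::sdpa", mutates_args=())
def demo_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Eager execution defers to the regular ATen SDPA kernel.
    return torch.ops.aten.scaled_dot_product_attention.default(q, k, v)

@torch.library.register_fake("demo::sdpa")
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation used while tracing/exporting.
    out_shape = list(q.shape)
    out_shape[-1] = v.shape[-1]
    return q.new_empty(out_shape)

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.ops.demo.sdpa(q, k, v)

q = k = v = torch.randn(1, 4, 8, 16)
ep = torch.export.export(Attn(), (q, k, v))
print(ep.graph_module.code)  # a single demo.sdpa node, not matmul/softmax

KVCacheCoreML relies on the same preserve-the-op idea: torch.ops.aten.index_put_(cache, [None, None, pos], val) expresses cache[:, :, pos] = val as a single op that the Core ML backend can map to iOS18.slice_update.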

extension/llm/export/partitioner_lib.py

Lines changed: 4 additions & 0 deletions
@@ -57,6 +57,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):

 def get_coreml_partitioner(
     enable_state: bool = False,
+    preserve_sdpa: bool = True,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
     coreml_quantize: Optional[str] = None,
@@ -78,6 +79,9 @@ def get_coreml_partitioner(
     # In Core ML, stateful execution is introduced in iOS 18
     if enable_state:
         minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    # In Core ML, sdpa op is introduced in iOS 18
+    if preserve_sdpa:
+        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
     # In Core ML, quantization is introduced in iOS 16
     if embedding_quantize is not None or pt2e_quantize is not None:
         minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
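
For clarity, a condensed sketch of the deployment-target bookkeeping this hunk extends (the function name and the iOS15 baseline are assumptions, not the real code):

import coremltools as ct

def minimum_target(enable_state: bool, preserve_sdpa: bool):
    target = ct.target.iOS15  # assumed baseline
    if enable_state:  # stateful execution requires iOS 18
        target = max(target, ct.target.iOS18)
    if preserve_sdpa:  # the fused SDPA op also requires iOS 18
        target = max(target, ct.target.iOS18)
    return target

assert minimum_target(False, True) == ct.target.iOS18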
