Skip to content

Commit a32e729

Browse files
cccclaifacebook-github-bot
authored andcommitted
apply simple sdpa to coreml/mps backend (#3660)
Summary: Pull Request resolved: #3660 coreml and mps doesn't support sdpa at the moment, use simple sdpa to have a simpler decomposition. Observer 1.5x faster on emulator. Reviewed By: shoumikhin, kirklandsign Differential Revision: D57476985 fbshipit-source-id: 2dbcad1a6e8b744e0a95d60fcd740369d665eab0
1 parent 39834b1 commit a32e729

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,9 +346,12 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
346346
if args.use_sdpa_with_kv_cache:
347347
transforms.append(replace_sdpa_with_custom_op)
348348

349-
if args.qnn and args.use_kv_cache:
350-
transforms.append(replace_sdpa_with_simple_sdpa)
351-
transforms.append(replace_causal_mask)
349+
if args.use_kv_cache:
350+
if args.qnn or args.coreml or args.mps:
351+
# Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
352+
# to get free perf gain.
353+
transforms.append(replace_sdpa_with_simple_sdpa)
354+
transforms.append(replace_causal_mask)
352355
return (
353356
load_llama_model(
354357
modelname=modelname,

0 commit comments

Comments
 (0)