
Commit 2726bdb

Lunwen He authored and facebook-github-bot committed
use --use_sdpa_with_kv_cache for 1B/3B bf16 (#5861)
Summary:
Pull Request resolved: #5861

We should use this option when exporting the 1B/3B models as bf16, because the KVCache is always fp32; otherwise we see regressed performance for the 1B/3B models in bf16 format.

ghstack-source-id: 246391007
Reviewed By: mergennachin
Differential Revision: D63871048
fbshipit-source-id: 6b3ff80dbc689a04c70e2fcc5c98698bb74f899b
1 parent: 478a9b6 · commit: 2726bdb
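The summary above can be pictured with a small, generic PyTorch sketch: the model's activations are bf16, but the KV cache stays fp32, so a naive attention path pays dtype conversions and cache copies on every decode step. This is illustrative only; the names and shapes are hypothetical, not ExecuTorch internals:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for a single attention head during decode.
head_dim, max_seq = 64, 128
k_cache = torch.zeros(max_seq, head_dim)  # KV cache is always fp32
v_cache = torch.zeros(max_seq, head_dim)

def naive_decode_step(q, k, v, pos):
    """One decode step with bf16 activations and an fp32 cache."""
    k_cache[pos] = k.float()                       # bf16 -> fp32 on every write
    v_cache[pos] = v.float()
    k_all = k_cache[: pos + 1].to(torch.bfloat16)  # fp32 -> bf16 copy of the
    v_all = v_cache[: pos + 1].to(torch.bfloat16)  # cached prefix, every step
    scores = (q @ k_all.T) / head_dim**0.5
    return F.softmax(scores, dim=-1) @ v_all

q = torch.randn(head_dim, dtype=torch.bfloat16)
k = torch.randn(head_dim, dtype=torch.bfloat16)
v = torch.randn(head_dim, dtype=torch.bfloat16)
out = naive_decode_step(q, k, v, pos=0)  # repeated once per generated token
```

With `--use_sdpa_with_kv_cache`, export instead uses ExecuTorch's custom SDPA-with-KV-cache operator, which handles the cache update and attention together so the fp32 cache does not force this kind of per-step round-tripping; the exact kernel behavior is beyond this sketch.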

File tree: 4 files changed (+9, -7 lines)


examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ We have supported BFloat16 as a data type on the XNNPACK backend for Llama 3.2 1
 * Export Llama model and generate .pte file as below:
 
 ```
-python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2.pte"
+python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2.pte"
 ```
 
 * Rename tokenizer for Llama 3.2 with command: `mv tokenizer.model tokenizer.bin`. We are updating the demo app to support tokenizer in original format directly.

examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ We have supported BFloat16 as a data type on the XNNPACK backend for Llama 3.2 1
 * Export Llama model and generate .pte file as below:
 
 ```
-python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2.pte"
+python -m examples.models.llama2.export_llama --checkpoint <checkpoint.pth> --params <params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2.pte"
 ```
 
 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
Binary file changed (1.01 MB); diff not shown.

examples/models/llama2/README.md

Lines changed: 7 additions & 5 deletions
@@ -85,10 +85,10 @@ We have verified running Llama 2 7B [mobile applications](#step-6-build-mobile-a
 ### Llama 3.2 1B and 3B
 Llama 3.2 1B and 3B performance was measured on the OnePlus 12 device. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-5-run-benchmark-on) for generating 128 tokens.
 
-|Model | 4bit(*) via SpinQuant
-|--------| ---------------
-|1B | 53.41 tokens/second |
-|3B | 22.98 tokens/second |
+|Model | bf16 | 4bit(*) via SpinQuant
+|--------| ---------------------- | ---------------
+|1B | 19.4 tokens/second | 53.41 tokens/second |
+|3B | 7.76 tokens/second | 22.98 tokens/second |
 
 (*) With SpinQuant, we currently quantize 4-bit groupwise (with groupsize 32) weight, 8bit dynamic activation of all the linear layers of the model, except embedding and output layers. The embedding and output layers are quantized as 8-bit per-channel weight and 8-bit dynamic activation.
 
@@ -142,7 +142,9 @@ LLAMA_PARAMS=path/to/params.json
 python -m examples.models.llama2.export_llama \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
-  -kv -X \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -X \
   -d bf16 \
   --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}' \
   --output_name="llama3_2.pte"
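The SpinQuant footnote in the diff above mentions 4-bit groupwise weight quantization with groupsize 32. As a rough illustration of what "4-bit groupwise" means, here is a generic PyTorch sketch; the helper names are made up, and this is not the SpinQuant or ExecuTorch implementation:

```python
import torch

def quantize_4bit_groupwise(w: torch.Tensor, group_size: int = 32):
    """Quantize each row of `w` in groups of `group_size` values.

    Assumes w.shape[1] is divisible by group_size.
    """
    out_f, in_f = w.shape
    groups = w.reshape(out_f, in_f // group_size, group_size)
    # One floating-point scale per group; signed 4-bit range is [-8, 7].
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7.0
    q = torch.clamp(torch.round(groups / scale), -8, 7).to(torch.int8)
    return q.reshape(out_f, in_f), scale.squeeze(-1)

def dequantize_4bit_groupwise(q, scale, group_size: int = 32):
    out_f, in_f = q.shape
    groups = q.reshape(out_f, in_f // group_size, group_size).float()
    return (groups * scale.unsqueeze(-1)).reshape(out_f, in_f)

w = torch.randn(16, 64)
q, s = quantize_4bit_groupwise(w)
w_hat = dequantize_4bit_groupwise(q, s)
print((w - w_hat).abs().max())  # small per-group reconstruction error
```

Each group of 32 weights shares one scale, so storage drops to roughly 4 bits per weight plus a small per-group overhead; the 8-bit dynamic-activation part of the scheme described in the footnote is not shown here.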
