Skip to content

Commit 957259e

Browse files
authored
Fix hardcoded rope_scale factor to 32 for Llama 3.2
Differential Revision: D67061188 Pull Request resolved: #7272
1 parent 59df3fe commit 957259e

File tree

7 files changed

+32
-12
lines changed

7 files changed

+32
-12
lines changed

backends/vulkan/docs/android_demo.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
5959
# The files will usually be downloaded to ~/.llama
6060
python -m examples.models.llama.export_llama \
6161
--disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
62+
--model "llama3_2" \
6263
-c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
6364
-p ~/.llama/checkpoints/Llama3.2-1B/params.json \
6465
--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
5656
Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
5757
* Export Llama model and generate .pte file as below:
5858
```
59-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
59+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
6060
```
6161

6262
### For Llama 3.2 1B and 3B QAT+LoRA models
6363
Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
6464
* Export Llama model and generate .pte file as below:
6565
```
66-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
66+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
6767
```
6868

6969
### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
7272
* Export Llama model and generate .pte file as below:
7373

7474
```
75-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
75+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
7676
```
7777

7878
For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
4848
Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
4949
* Export Llama model and generate .pte file as below:
5050
```
51-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
51+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
5252
```
5353

5454
### For Llama 3.2 1B and 3B QAT+LoRA models
5555
Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
5656
* Export Llama model and generate .pte file as below:
5757
```
58-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
58+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
5959
```
6060

6161
### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
6464
* Export Llama model and generate .pte file as below:
6565

6666
```
67-
python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
67+
python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
6868
```
6969

7070
For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/models/llama/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
168168
LLAMA_PARAMS=path/to/params.json
169169
170170
python -m examples.models.llama.export_llama \
171+
--model "llama3_2" \
171172
--checkpoint "${LLAMA_CHECKPOINT:?}" \
172173
--params "${LLAMA_PARAMS:?}" \
173174
-kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
189190
LLAMA_PARAMS=path/to/spinquant/params.json
190191
191192
python -m examples.models.llama.export_llama \
193+
--model "llama3_2" \
192194
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
193195
--params "${LLAMA_PARAMS:?}" \
194196
--use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
214216
LLAMA_PARAMS=path/to/qlora/params.json
215217
216218
python -m examples.models.llama.export_llama \
219+
--model "llama3_2" \
217220
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
218221
--params "${LLAMA_PARAMS:?}" \
219222
-qat \

examples/models/llama/llama_transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class ModelArgs:
113113
)
114114
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
115115
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
116+
rope_scale_factor: int = 8
116117
# Additional Model Metadata needed at runtime
117118
bos_idx: int = 1
118119
eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
155156
self.precompute_freqs_cis = hf_precompute_freqs_cis
156157
else:
157158
self.precompute_freqs_cis = partial(
158-
precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
159+
precompute_freqs_cis,
160+
use_scaled=self.params.use_scaled_rope,
161+
scale_factor=self.params.rope_scale_factor,
159162
)
160163
freqs_cos, freqs_sin = self.precompute_freqs_cis(
161164
self.params.head_dim,

examples/models/llama/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
145145
enable_dynamic_shape=self.enable_dynamic_shape,
146146
**params,
147147
)
148+
149+
if model_args.use_scaled_rope:
150+
# Older models don't have use_scaled_rope configuration
151+
assert self.args.model not in ["llama2", "stories110m"]
152+
153+
# Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
154+
if self.args.model not in ["llama3", "llama3_1"]:
155+
model_args.rope_scale_factor = 32
156+
148157
if kwargs.get("verbose", False):
149158
print("============= weights ================")
150159
print("{key} : {weights.numel()} : {weights.size()}")

examples/models/llama/rope.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@
88
# Different RoPE implementations
99

1010
import math
11-
from typing import Tuple
11+
from typing import Optional, Tuple
1212

1313
import torch
1414

1515
# ======================== Stock Implementation ========================
1616

1717

18-
def apply_scaling(freqs: torch.Tensor):
18+
def apply_scaling(freqs: torch.Tensor, scale_factor: int):
1919
# Values obtained from grid search
20-
scale_factor = 8
2120
low_freq_factor = 1
2221
high_freq_factor = 4
2322
old_context_len = 8192 # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):
4140

4241

4342
def precompute_freqs_cis(
44-
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
43+
dim: int,
44+
end: int,
45+
theta: float = 10000.0,
46+
use_scaled: bool = False,
47+
scale_factor: Optional[int] = None,
4548
):
4649
freqs = 1.0 / (
4750
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
4851
)
4952
t = torch.arange(end, device=freqs.device) # pyre-ignore
5053
if use_scaled:
51-
freqs = apply_scaling(freqs) # pyre-ignore
54+
assert scale_factor is not None
55+
freqs = apply_scaling(freqs, scale_factor) # pyre-ignore
5256
freqs = torch.outer(t, freqs).float()
5357
freqs_cos = torch.cos(freqs)
5458
freqs_sin = torch.sin(freqs)

0 commit comments

Comments
 (0)