
Support quantized llama models #6486


Merged

merged 1 commit on Oct 24, 2024

2 changes: 1 addition & 1 deletion README.md
@@ -25,7 +25,7 @@ Check out the [Getting Started](https://pytorch.org/executorch/stable/getting-st
Check out the examples of [Llama](./examples/models/llama/README.md), [Llava](./examples/models/llava/README.md) and [other models](./examples/README.md) running on edge devices using ExecuTorch.


**[UPDATE - 09/25]** We have added support for running [Llama 3.2 1B/3B](./examples/models/llama/README.md) models via ExecuTorch.
**[UPDATE - 10/24]** We have added support for running [Llama 3.2 Quantized 1B/3B](./examples/models/llama/README.md) models via ExecuTorch.

## Feedback

105 changes: 75 additions & 30 deletions examples/models/llama/README.md
@@ -4,6 +4,7 @@ This example demonstrates how to run [Llama models](https://www.llama.com/) on m
Here are supported models:

- Llama 3.2 1B and 3B
- Llama 3.2 Quantized 1B and 3B
- Llama 3.1 8B
- Llama 3 8B
- [Llama 2 7B](../llama2/README.md)
@@ -24,48 +25,62 @@ Please note that the models are subject to the [Llama 2 Acceptable Use Policy](h

# Results

## Llama 3.2 1B/3B
## Llama 3.2 1B/3B and quantized 1B/3B models

For Llama 3.2 1B/3B models, we have enabled the original bf16 format and quantization to 4-bit, using SpinQuant, for enhanced performance.
For Llama 3.2 1B/3B models, we have enabled the original BF16 format and quantization to 4-bit, using SpinQuant and QAT+LoRA, for enhanced performance.

### 1. Enablement
The quantized models were optimized primarily for the Arm CPU architecture by leveraging XNNPACK and the KleidiAI library. Work is underway to enable quantization specifically on mobile accelerators for Llama 1B/3B.

### Enablement

We have successfully verified performance on the following devices: iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, S22 and OnePlus 12 (featuring 16GB RAM).

Note, the Llama 3.2 3B unquantized bf16 model was only tested on the OnePlus 12, which has sufficient memory (16GB RAM) to support its size requirements.
Note that the unquantized Llama 3.2 3B BF16 model was only tested on the OnePlus 12, which has sufficient memory (16GB of RAM) to meet its size requirements.

### Quantization

### 2. Quantization
The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied. To achieve a balance between accuracy, performance, and memory, we utilized 4-bit quantization using the [SpinQuant](https://github.com/facebookresearch/SpinQuant/tree/main) and QAT+LoRA methods.

#### 2.1 SpinQuant
Our quantization scheme involves three parts, applicable to both methods (a sketch of the corresponding export flags follows this list):

The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied. To achieve a balance between accuracy, performance and memory, we utilized 4-bit quantization with [SpinQuant](https://github.com/facebookresearch/SpinQuant/tree/main). With SpinQuant, we currently quantize 4-bit groupwise (with groupsize 32) weight, 8bit dynamic activation of all the linear layers of the model, except embedding and output layers. The embedding and output layers are quantized as 8-bit per-channel weight and 8-bit dynamic activation.
- We quantize all linear layers in all transformer blocks to a 4-bit groupwise scheme (with a group size of 32) for weights and 8-bit per-token dynamic quantization for activations.
- The classification layer is quantized to 8-bit per-channel for weights and 8-bit per-token dynamic quantization for activations.
- We employ 8-bit per-channel quantization for the embedding layer.
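
For reference, the sketch below shows how this scheme maps onto the prequantized-checkpoint export flags used later in this README. It is illustrative only (placeholder paths and output name, not a complete export recipe); the full SpinQuant and QAT+LoRA commands appear in the export instructions below.

```
# Illustrative mapping of the scheme onto export flags (see the full commands below):
#   --preq_mode 8da4w_output_8da8w  -> 8-bit dynamic activations + 4-bit groupwise weights,
#                                      with the output/classification layer at 8-bit weights
#   --preq_group_size 32            -> weight group size of 32
#   --preq_embedding_quantize 8,0   -> 8-bit per-channel embedding quantization
python -m examples.models.llama.export_llama \
  --checkpoint path/to/prequantized/checkpoint.pth \
  --params path/to/params.json \
  --preq_mode 8da4w_output_8da8w \
  --preq_group_size 32 \
  --preq_embedding_quantize 8,0 \
  --output_name "llama3_2_prequant.pte"
```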

#### SpinQuant

The SpinQuant method takes the original weights and produces optimized quantized weights with minimal outliers, resulting in higher accuracy. This can be achieved without any finetuning of the weights and only requires 100 iterations on a single A100 node.

SpinQuant can generate quantized weights that are [compatible with ExecuTorch](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch); specifically, they can be integrated with the existing optimized XNNPACK kernels (e.g., group-wise 4-bit weights and 8-bit dynamic activations). This allows developers to benefit from the higher accuracy of SpinQuant while also taking advantage of the strong performance of ExecuTorch acceleration.

### 3. Accuracy
#### Quantization-Aware Training and LoRA (QAT+LoRA)

Quantization-Aware Training (QAT) is employed to simulate the effects of quantization during the training of Llama 3.2 models, enabling optimization of their performance in low-precision environments. To initialize QAT, the BF16 Llama 3.2 model checkpoints obtained after supervised fine-tuning (SFT) are used, and an additional full round of SFT with QAT is performed. The backbone of the QAT model is then frozen and another round of SFT is performed with low-rank adaptation (LoRA) adaptors applied to all layers within the transformer block. The LoRA adaptors' weights and activations are kept in BF16.

### Accuracy

Please see the [Llama 3.2 model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) for accuracy evaluations.

### 4. Performance:
### Performance

Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-5-run-benchmark-on) with prompt length of 64.
Llama 3.2 1B and 3B performance was measured on an Android OnePlus 12 device. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-4-run-benchmark-on-android-phone) with a prompt length of 64, measured with the KleidiAI library. KleidiAI is not yet enabled by default; use `-DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON` to enable it in the build (see the sketch after the table below).

|Model | decode (tokens/s) | prefill (tokens/s) | Memory size (RSS in MiB) |
|-------|------------------------ |------------------ | ------------------ |
|1B bf16 | 19.2 | 60.3 | 3,185 |
|1B SpinQuant | 50.2 | 260.5 | 1,921 |
|3B bf16 | 7.6 | 21.2 | 7,419 |
|3B SpinQuant | 19.7 | 89.7 | 3,726 |
|Model | Decode (tokens/s) | Time-to-first-token (sec) | Prefill (tokens/s) | Model size (PTE file size in MiB) | Memory size (RSS in MiB) |
|-------|------------------:|--------------------------:| ------------------:|----------------------------------:| ------------------------:|
|1B BF16 (baseline) | 19.2 | 1.0 | 60.3 | 2,358 | 3,185 |
|1B SpinQuant | 50.2 (2.6x) | 0.3 (-76.9%) | 260.5 (4.3x) | 1,083 (-54.1%) | 1,921 (-39.7%) |
|1B QAT+LoRA | 45.8 (2.4x) | 0.3 (-76.0%) | 252.0 (4.2x) | 1,127 (-52.2%) | 2,255 (-29.2%) |
|3B BF16 (baseline) | 7.6 | 3.0 | 21.2 | 6,129 | 7,419 |
|3B SpinQuant | 19.7 (2.6x) | 0.7 (-76.4%) | 89.7 (4.2x) | 2,435 (-60.3%) | 3,726 (-49.8%) |
|3B QAT+LoRA | 18.5 (2.4x) | 0.7 (-76.1%) | 88.8 (4.2x) | 2,529 (-58.7%) | 4,060 (-45.3%) |
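
As a rough illustration (not a complete command), KleidiAI is switched on at CMake configure time; the placeholder below stands in for the remaining options from the Step 4 Android build instructions.

```
# Sketch only: enable KleidiAI in the XNNPACK build (other options per Step 4;
# <...> is a placeholder).
cmake -DEXECUTORCH_BUILD_XNNPACK=ON \
      -DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON \
      <other Step 4 build options> \
      -Bcmake-out-android .
```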


<table>
<tr>
<td>
<img src="./Android3_2_1B_bf16.gif" width="300">
<br>
<em> Llama3.2 1B, unquantized, bf16 on Android phone. </em>
<em> Llama3.2 1B, unquantized, BF16 on Android phone. </em>
</td>
<td>
<img src="./Android3_2_3B_SpinQuant.gif" width="300">
@@ -80,15 +95,15 @@ Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The p
## Llama 3/3.1 8B
Since the Llama 3 8B model needs at least 4-bit quantization to fit even on some of the high-end phones, the results presented here correspond to a 4-bit groupwise post-training quantized (PTQ) model.

### 1. Enablement
### Enablement

For Llama 3 8B and Llama 3.1 8B, we have so far verified performance on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, and OnePlus 12 (with 16GB RAM) by quantizing to 4-bit.

### 2. Quantization
### Quantization

We employed PTQ with 4-bit groupwise weight quantization and per-token dynamic activation quantization for all the linear layers of the model. Dynamic quantization refers to quantizing activations dynamically, such that the quantization parameters for activations are calculated, from the min/max range, at runtime. Here we quantized activations to 8 bits (signed integer). Furthermore, weights are statically quantized; in our case, weights were per-channel groupwise quantized with 4-bit signed integers. Due to Llama 3's vocabulary size, we had to quantize the embedding lookup table as well. For these results, the embedding lookup table was groupwise quantized with 4 bits and a group size of 32.
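
As a purely hypothetical illustration (the `-qmode` and `--group_size` flag names are assumptions here; the authoritative commands are in the Option B export instructions below), an export along these lines would produce the group-size-128 variant:

```
# Hypothetical sketch of a 4-bit groupwise PTQ export (group size 128).
# The -qmode and --group_size flag names are assumptions here; follow the
# Option B export instructions below for the authoritative command.
python -m examples.models.llama.export_llama \
  --checkpoint path/to/consolidated.00.pth \
  --params path/to/params.json \
  -kv \
  --use_sdpa_with_kv_cache \
  -X \
  -qmode 8da4w \
  --group_size 128 \
  --embedding-quantize 4,32 \
  --output_name "llama3_8b_4bit.pte"
```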

### 3. Accuracy
### Accuracy

We evaluated UncycloText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different group sizes, with max_seq_length 2048 and limit 1000 (an illustrative harness invocation is sketched at the end of this subsection).

@@ -98,9 +113,9 @@ We evaluated UncycloText perplexity using [LM Eval](https://github.com/EleutherAI/l

Please note that LM Eval reports perplexity normalized by word count instead of token count. You may see different perplexity for UncycloText from other sources if they implement it differently. More details can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/issues/2301).
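
For orientation only, a generic harness invocation might look like the sketch below; the model name is assumed and this is not necessarily the exact evaluation setup used for the numbers above.

```
# Illustrative only: a generic lm-evaluation-harness run against a Hugging Face
# baseline checkpoint (model name assumed).
lm_eval --model hf \
  --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct \
  --tasks wikitext \
  --limit 1000
```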

### 4. Performance
### Performance

Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 12 devices. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-5-run-benchmark-on).
Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 12 devices. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-4-run-benchmark-on-android-phone).

|Device | Groupwise 4-bit (128) | Groupwise 4-bit (256)
|--------| ---------------------- | ---------------
@@ -137,9 +152,11 @@ Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus

1. Download `consolidated.00.pth`, `params.json` and `tokenizer.model` from [Llama website](https://www.llama.com/llama-downloads/) or [Hugging Face](https://huggingface.co/meta-llama/Llama-3.2-1B). For chat use-cases, download the instruct models.

2. Export model and generate `.pte` file. Use original bfloat16 version, without any quantization.
2. Export model and generate `.pte` file.

- Use the **original BF16** version, without any quantization.
```
# No quantization
# Set these paths to point to the downloaded files
LLAMA_CHECKPOINT=path/to/checkpoint.pth
LLAMA_PARAMS=path/to/params.json
@@ -155,20 +172,22 @@ python -m examples.models.llama.export_llama \
--output_name="llama3_2.pte"
```

Optionally, we can apply SpinQuant to quantize the model without sacrificing too much accuracy.

To use SpinQuant, follow its [instruction](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch) for exporting checkpoint to ExecuTorch and then export the SpinQuant checkpoint.
- To use **SpinQuant**, there are two ways:
- Download directly from the [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to a `.pte` file directly.
- Follow the SpinQuant [instructions](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch) for exporting a checkpoint to ExecuTorch, and then export the SpinQuant checkpoint.

```
# SpinQuant
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
LLAMA_PARAMS=path/to/params.json
LLAMA_PARAMS=path/to/spinquant/params.json

python -m examples.models.llama.export_llama \
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
--params "${LLAMA_PARAMS:?}" \
--use_sdpa_with_kv_cache \
-X \
--xnnpack-extended-ops \
--preq_mode 8da4w_output_8da8w \
--preq_group_size 32 \
--max_seq_length 2048 \
@@ -180,6 +199,32 @@ python -m examples.models.llama.export_llama \
--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```

- To use **QAT+LoRA**, download directly from the [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to a `.pte` file directly by:

```
# QAT+LoRA
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
LLAMA_PARAMS=path/to/qlora/params.json

python -m examples.models.llama.export_llama \
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
--params "${LLAMA_PARAMS:?}" \
-qat \
-lora 16 \
--preq_mode 8da4w_output_8da8w \
--preq_group_size 32 \
--preq_embedding_quantize 8,0 \
--use_sdpa_with_kv_cache \
-kv \
-X \
--xnnpack-extended-ops \
-d fp32 \
--max_seq_length 2048 \
--output_name "llama3_2.pte" \
--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```

### Option B: Download and export Llama 3 8B instruct model

You can export and run the original Llama 3 8B instruct model.
@@ -193,7 +238,7 @@ You can export and run the original Llama 3 8B instruct model.

Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.

## Step 4: Run on your computer to validate
## Step 3: Run on your computer to validate

1. Build ExecuTorch with optimized CPU performance as follows. Build options are available [here](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L59).
```
@@ -236,7 +281,7 @@ Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the

To build for the CoreML backend and validate on a Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON`, as sketched below.
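
For illustration, a minimal sketch of that swap at configure time (the placeholder and the `cmake-out` build directory are assumptions; the remaining options follow step 1 above):

```
# Sketch only: CoreML configure step; <...> stands in for the other step-1 options.
cmake -DEXECUTORCH_BUILD_COREML=ON \
      <other build options from step 1> \
      -Bcmake-out .
```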

## Step 5: Run benchmark on Android phone
## Step 4: Run benchmark on Android phone

**1. Build llama runner binary for Android**

@@ -301,7 +346,7 @@ adb push cmake-out-android/examples/models/llama/llama_main /data/local/tmp/llam

**2.3 Run model**
```
adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.model> --prompt \"Once upon a time\" --seq_len 120"
adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.model> --prompt \"What is the capital of France?\" --seq_len 120" --warmup=1
```
## Step 6: Build Mobile apps
