
Commit 7e35fe6

Update on "[ET-VK] Replace Uniform buffers with push constants for view op"
This diff replaces uniform buffers with push constants for the view op in the Vulkan backend of ExecuTorch. The changes include updating the GLSL code to use push constants instead of uniform buffers and updating the C++ code to pass the sizes as push constants to the shader.

Differential Revision: [D66733658](https://our.internmc.facebook.com/intern/diff/D66733658/)

[ghstack-poisoned]
2 parents: 9fde95a + c352184

File tree

9 files changed (+37 / -17 lines)


backends/arm/test/ops/test_layer_norm.py

Lines changed: 4 additions & 4 deletions
@@ -157,9 +157,9 @@ def test_layer_norm_tosa_BI(

     # Numerical issues on FVP likely due to mul op, MLETORCH-521
     # Skip tests that require transposes.
-    @parameterized.expand(test_data_suite[:-2])
+    @parameterized.expand(test_data_suite)
     @unittest.expectedFailure
-    def test_layer_norm_u55_BI(
+    def test_layer_norm_u55_BI_xfails(
         self,
         test_name: str,
         test_data: torch.Tensor,
@@ -171,7 +171,8 @@ def test_layer_norm_u55_BI(

     # Numerical issues on FVP likely due to mul op, MLETORCH-521
     @parameterized.expand(test_data_suite[:-2])
-    def test_layer_norm_u85_BI_fvp(
+    @unittest.expectedFailure
+    def test_layer_norm_u85_BI_xfails(
         self,
         test_name: str,
         test_data: torch.Tensor,
@@ -182,7 +183,6 @@ def test_layer_norm_u85_BI_fvp(
         )

     @parameterized.expand(test_data_suite[-2:])
-    @unittest.skip # Flaky
     def test_layer_norm_u85_BI(
         self,
         test_name: str,

backends/vulkan/docs/android_demo.md

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
 # The files will usually be downloaded to ~/.llama
 python -m examples.models.llama.export_llama \
   --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
+  --model "llama3_2" \
   -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
   -p ~/.llama/checkpoints/Llama3.2-1B/params.json \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ tosa_reference_model_rev="c5570b79e90c3a36ab8c4ddb8ee3fbc2cd3f7c38"

 # vela
 vela_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u-vela"
-vela_rev="a08fc18780827b5fefc814dd0162ee6317ce0ae7"
+vela_rev="5427dc7e9c1a4c7d554163290faeea75f168772d"

 ########
 ### Mandatory user args

examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```

 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```

 ### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:

 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```

 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md

Lines changed: 3 additions & 3 deletions
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```

 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```

 ### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:

 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```

 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).

examples/models/llama/README.md

Lines changed: 3 additions & 0 deletions
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json

 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/spinquant/params.json

 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
 LLAMA_PARAMS=path/to/qlora/params.json

 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -qat \

examples/models/llama/llama_transformer.py

Lines changed: 4 additions & 1 deletion
@@ -113,6 +113,7 @@ class ModelArgs:
     )
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
+    rope_scale_factor: int = 8
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,

examples/models/llama/model.py

Lines changed: 9 additions & 0 deletions
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
+
+        if model_args.use_scaled_rope:
+            # Older models don't have use_scaled_rope configuration
+            assert self.args.model not in ["llama2", "stories110m"]
+
+            # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
+            if self.args.model not in ["llama3", "llama3_1"]:
+                model_args.rope_scale_factor = 32
+
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")

examples/models/llama/rope.py

Lines changed: 9 additions & 5 deletions
@@ -8,16 +8,15 @@
 # Different RoPE implementations

 import math
-from typing import Tuple
+from typing import Optional, Tuple

 import torch

 # ======================== Stock Implementation ========================


-def apply_scaling(freqs: torch.Tensor):
+def apply_scaling(freqs: torch.Tensor, scale_factor: int):
     # Values obtained from grid search
-    scale_factor = 8
     low_freq_factor = 1
     high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):


 def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    use_scaled: bool = False,
+    scale_factor: Optional[int] = None,
 ):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
-        freqs = apply_scaling(freqs)  # pyre-ignore
+        assert scale_factor is not None
+        freqs = apply_scaling(freqs, scale_factor)  # pyre-ignore
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
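A quick usage sketch of the updated signature, assuming the module is importable as `examples.models.llama.rope` from the repository root and with illustrative `dim`/`end` values: once `use_scaled=True`, callers must now pass `scale_factor` explicitly, otherwise the new assert fires.

```
from examples.models.llama.rope import precompute_freqs_cis

# Unscaled RoPE: scale_factor can be omitted (stays None).
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=128, theta=10000.0)

# Scaled RoPE (llama3.1-style): scale_factor is now mandatory.
scaled_cos, scaled_sin = precompute_freqs_cis(
    dim=64, end=128, use_scaled=True, scale_factor=8
)

# Both calls return (end, dim // 2) tensors of cos/sin frequencies.
assert freqs_cos.shape == scaled_cos.shape == (128, 32)
```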
