
Commit f5ec6cf

Update on "Add a simple sdpa"
Add a simple SDPA so that attention decomposes into simpler ops, instead of relying on the decomposition of `F.scaled_dot_product_attention`, which expands into 29 ops, including `torch.where`:

```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605);  q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2);  aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1);  aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1);  aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0);  aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default);  aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default);  aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default);  aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]);  k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605);  aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]);  aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]);  aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]);  aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]);  aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1);  aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]);  aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self);  aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False);  aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]);  aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]);  aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]);  v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]);  aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4);  aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]);  aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```

Differential Revision: [D56119737](https://our.internmc.facebook.com/intern/diff/D56119737/)

[ghstack-poisoned]
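For reference, below is a minimal sketch of the attention math the new module computes directly; it mirrors the `scale_factor` / `attn_weight` lines added in `export_llama_lib.py` further down, but the `simple_sdpa` helper name, the shapes, and the mask construction are illustrative assumptions, not code from this commit. (The constant 0.5946035575013605 in the graph above appears to be `head_dim ** -0.25` for `head_dim = 8`, i.e. the `1/sqrt(head_dim)` scale split between q and k.)

```python
import math
import torch

def simple_sdpa(q, k, v, attn_mask):
    # q, k, v: (bs, n_heads, seqlen, head_dim); attn_mask: additive (0 / -inf) causal mask
    scale_factor = 1 / math.sqrt(q.size(-1))
    attn_weight = q @ k.transpose(-2, -1) * scale_factor  # (bs, n_heads, seqlen, seqlen)
    attn_weight += attn_mask                              # apply the causal mask additively
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ v                                # (bs, n_heads, seqlen, head_dim)

# Illustrative shapes matching the decomposed graph above (1 head, seqlen 8, head_dim 8).
q = k = v = torch.randn(1, 1, 8, 8)
mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)
out = simple_sdpa(q, k, v, mask)
```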
2 parents: 8e99c7a + f25c479

12 files changed (+141, -375 lines)

.ci/scripts/build_llama_android.sh

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ install_executorch_and_backend_lib() {
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
+    -DEXECUTORCH_BUILD_QUANTIZED=ON \
     -DXNNPACK_ENABLE_ARM_BF16=OFF \
     -Bcmake-android-out .

examples/models/llama2/README.md

Lines changed: 8 additions & 5 deletions
@@ -5,7 +5,7 @@ This example demonstrates how to run a [Llama 2](https://ai.meta.com/llama/) 7B
 For Llama2, please refer to [the llama's github page](https://github.com/facebookresearch/llama) for details.
 Pretrained parameters are not included in this repo. Users are suggested to download them through [the llama's download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
 
-# What is Llama 2?
+# What are Llama 2 and 3?
 Llama is a family of large language models that uses publicly available data for training. These models are based on the transformer architecture, which allows it to process input sequences of arbitrary length and generate output sequences of variable length. One of the key features of Llama models is its ability to generate coherent and contextually relevant text. This is achieved through the use of attention mechanisms, which allow the model to focus on different parts of the input sequence as it generates output. Additionally, Llama models use a technique called “masked language modeling” to pre-train the model on a large corpus of text, which helps it learn to predict missing words in a sentence.
 
 Llama models have shown to perform well on a variety of natural language processing tasks, including language translation, question answering, and text summarization and are also capable of generating human-like text, making Llama models a useful tool for creative writing and other applications where natural language generation is important.
@@ -17,7 +17,9 @@ Please note that the models are subject to the [acceptable use policy](https://g
 
 # Results
 
-Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.
+Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.
+
+For Llama3, we can use the same process. Note that it's only supported in the ExecuTorch main branch.
 
 ## Quantization:
 We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://github.com/pytorch-labs/ao/).
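To make the groupwise scheme concrete, here is a minimal sketch of 4-bit signed, groupwise weight quantization with per-group scales derived from the value range. This is an illustration only, not the actual code from pytorch-labs/ao or this repo; the function names, group size, and shapes are assumptions.

```python
import torch

def quantize_groupwise_4bit(w: torch.Tensor, group_size: int = 32):
    # w: (out_features, in_features); split each row into groups of `group_size` weights.
    out_features, in_features = w.shape
    wg = w.reshape(out_features, in_features // group_size, group_size)
    scales = wg.abs().amax(dim=-1, keepdim=True) / 7.0        # symmetric int4 range [-8, 7]
    q = torch.clamp(torch.round(wg / scales), min=-8, max=7).to(torch.int8)
    return q, scales

def dequantize_groupwise(q: torch.Tensor, scales: torch.Tensor, shape):
    # Reconstruct an approximation of the original weights from int4 values and per-group scales.
    return (q.float() * scales).reshape(shape)

w = torch.randn(16, 64)
q, scales = quantize_groupwise_4bit(w)
w_hat = dequantize_groupwise(q, scales, w.shape)
```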
@@ -230,7 +232,7 @@ adb push cmake-out-android/examples/models/llama2/llama_main /data/local/tmp/lla
 
 **2.3 Run model**
 ```
-adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.bin> --prompt "Once upon a time" --seq_len 120
+adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.bin> --prompt \"Once upon a time\" --seq_len 120"
 ```
 ## Step 6: Build Mobile apps
@@ -263,12 +265,13 @@ This example tries to reuse the Python code, with minimal modifications to make
 3. No dependencies on fairscale. The ColumnParallelLinear, ParallelEmbedding and training are not needed and supported in ExecuTorch.
 
 
-# Clean
-To clean your build:
+# Common Issues and Mitigations:
+- To clean your build:
 ```
 git clean -xfd
 pip uninstall executorch
 ./install_requirements.sh <options>
 
 rm -rf cmake-out
 ```
+- If you encounter `pthread` related issues during link time, add `pthread` in `target_link_libraries` in `CMakeLists.txt`

examples/models/llama2/builder.py

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,7 @@ def to_torch_dtype(self) -> torch.dtype:
 
 def load_llama_model(
     *,
+    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
@@ -114,6 +115,7 @@ def load_llama_model(
 
     return LlamaEdgeManager(
         model=model,
+        modelname=modelname,
         weight_type=weight_type,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
@@ -131,6 +133,7 @@ class LlamaEdgeManager:
     def __init__(
         self,
        model,
+        modelname,
        weight_type,
        dtype,
        use_kv_cache,
@@ -139,6 +142,7 @@ def __init__(
        verbose: bool = False,
    ):
        self.model = model
+        self.modelname = modelname
        self.weight_type = weight_type
        self.dtype = dtype
        self.example_inputs = example_inputs

examples/models/llama2/export_llama_lib.py

Lines changed: 15 additions & 3 deletions
@@ -38,7 +38,7 @@
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
 
-from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
+from .quantize import EmbeddingQuantHandler, WeightOnlyInt8QuantHandler
 
 
 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
@@ -145,6 +145,10 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 
 
 class SDPASimple(torch.nn.Module):
+    """
+    This is a simpler implementation of the SDPA module defined in llama_transformer.py. Note that
+    the original SDPA module includes both some preprocessing logic and F.scaled_dot_product_attention.
+    """
     def __init__(
         self,
         kv_cache: KVCache,
@@ -168,6 +172,7 @@ def forward(
         seqlen,
         mask,
     ):
+        # The first few lines are the same as in the original SDPA module.
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -177,6 +182,11 @@ def forward(
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
+
+        # The following is the part that differs. Instead of calling F.scaled_dot_product_attention,
+        # we use this implementation to avoid the decomposition of F.scaled_dot_product_attention,
+        # as that decomposition is too expensive. It gets rid of aten.full_like, aten.logical_not,
+        # aten.scalar_tensor, aten.where and 2 extra aten.mul.
         scale_factor = 1 / math.sqrt(q.size(-1))
         attn_weight = q @ k.transpose(-2, -1) * scale_factor
         attn_weight += attn_mask
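For context on how such a module gets wired in, a source-transformation pass like the `replace_sdpa_with_custom_op` referenced in the hunk above typically swaps submodules in place. The sketch below is generic and hedged: `OriginalBlock` and `SimpleBlock` are stand-ins rather than the actual SDPA/SDPASimple classes, and this is not the ExecuTorch implementation.

```python
import torch

class OriginalBlock(torch.nn.Module):  # stand-in for the original SDPA module
    def forward(self, x):
        return x

class SimpleBlock(torch.nn.Module):    # stand-in for the simpler replacement
    def forward(self, x):
        return x

def replace_blocks(module: torch.nn.Module) -> torch.nn.Module:
    # Recursively replace every OriginalBlock child with a SimpleBlock, leaving the rest intact.
    for name, child in module.named_children():
        if isinstance(child, OriginalBlock):
            setattr(module, name, SimpleBlock())
        else:
            replace_blocks(child)
    return module

model = torch.nn.Sequential(torch.nn.Linear(4, 4), OriginalBlock())
model = replace_blocks(model)
```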
@@ -559,7 +569,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     )
     params_path = canonical_path(args.params)
     output_dir_path = canonical_path(args.output_dir, dir=True)
-    modelname = "llama2"
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
     # dtype override
@@ -613,7 +622,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         group_size = int(group_size)
         bitwidth = int(bitwidth)
         transforms.append(
-            lambda model: EmbeddingOnlyInt8QuantHandler(
+            lambda model: EmbeddingQuantHandler(
                 model, bitwidth=bitwidth, group_size=group_size
             ).quantized_model()
         )
@@ -626,6 +635,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
 
     return (
         load_llama_model(
+            modelname=modelname,
            checkpoint=checkpoint_path,
            checkpoint_dir=checkpoint_dir,
            params_path=params_path,
@@ -673,6 +683,8 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
        modelname, args
    ).export_to_edge(quantizers)
 
+    modelname = builder_exported_to_edge.modelname
+
    # to_backend
    partitioners = []
    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None: