
Commit f25c479

Update base for Update on "Add a simple sdpa"
Add a simple sdpa so it's decomposed into simpler ops, instead of the default decomposition of F.scaled_dot_product_attention, which includes 29 ops, including `torch.where`:

```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605); q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2); aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1); aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1); aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0); aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default); aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default); aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default); aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]); k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605); aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]); aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]); aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]); aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]); aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1); aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]); aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self); aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False); aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]); aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]); aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]); v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]); aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4); aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]); aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```

Differential Revision: [D56119737](https://our.internmc.facebook.com/intern/diff/D56119737/)

[ghstack-poisoned]
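For reference, the "simple" sdpa referred to above is essentially the textbook attention computation. The sketch below is added for illustration only (the module and argument names are mine, not code from this commit); it shows the kind of formulation that lowers to a handful of mul/bmm/add/softmax edge ops rather than the 29-op decomposition shown above.

```
# Illustrative sketch of a "simple" scaled dot-product attention, assuming an
# additive attention mask (0 for kept positions, -inf for masked positions).
import math
from typing import Optional

import torch
import torch.nn as nn


class SimpleSDPA(nn.Module):
    def forward(
        self,
        q: torch.Tensor,  # [batch, heads, seq_len, head_dim]
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(q.size(-1))
        # [batch, heads, seq_len, seq_len] attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            scores = scores + mask
        probs = torch.softmax(scores, dim=-1)
        return torch.matmul(probs, v)
```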
2 parents e170aa5 + 523c2cb commit f25c479

12 files changed (+131 / -375 lines)


.ci/scripts/build_llama_android.sh
Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ install_executorch_and_backend_lib() {
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
     -DEXECUTORCH_BUILD_OPTIMIZED=ON \
+    -DEXECUTORCH_BUILD_QUANTIZED=ON \
     -DXNNPACK_ENABLE_ARM_BF16=OFF \
     -Bcmake-android-out .

examples/models/llama2/README.md
Lines changed: 8 additions & 5 deletions

@@ -5,7 +5,7 @@ This example demonstrates how to run a [Llama 2](https://ai.meta.com/llama/) 7B
 For Llama2, please refer to [the llama's github page](https://github.com/facebookresearch/llama) for details.
 Pretrained parameters are not included in this repo. Users are suggested to download them through [the llama's download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).

-# What is Llama 2?
+# What are Llama 2 and 3?
 Llama is a family of large language models that uses publicly available data for training. These models are based on the transformer architecture, which allows it to process input sequences of arbitrary length and generate output sequences of variable length. One of the key features of Llama models is its ability to generate coherent and contextually relevant text. This is achieved through the use of attention mechanisms, which allow the model to focus on different parts of the input sequence as it generates output. Additionally, Llama models use a technique called “masked language modeling” to pre-train the model on a large corpus of text, which helps it learn to predict missing words in a sentence.

 Llama models have shown to perform well on a variety of natural language processing tasks, including language translation, question answering, and text summarization and are also capable of generating human-like text, making Llama models a useful tool for creative writing and other applications where natural language generation is important.

@@ -17,7 +17,9 @@ Please note that the models are subject to the [acceptable use policy](https://g

 # Results

-Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.
+Since 7B Llama2 model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized model.
+
+For Llama3, we can use the same process. Note that it's only supported in the ExecuTorch main branch.

 ## Quantization:
 We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizing activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://github.com/pytorch-labs/ao/).
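To illustrate the dynamic quantization described in that paragraph, here is a toy per-tensor sketch added for reference (the actual scheme is per-token for activations and groupwise per-channel for weights, and the function below is not ExecuTorch code): the scale and zero point are derived from the observed min/max range at runtime, then values are rounded into int8.

```
# Toy asymmetric int8 dynamic quantization: parameters come from the runtime
# min/max of the tensor. Per-tensor for simplicity; names are illustrative.
import torch


def dynamic_quantize_int8(x: torch.Tensor):
    qmin, qmax = -128, 127
    # include zero in the range so it is exactly representable
    x_min = torch.minimum(x.min(), torch.zeros((), dtype=x.dtype))
    x_max = torch.maximum(x.max(), torch.zeros((), dtype=x.dtype))
    scale = torch.clamp((x_max - x_min) / (qmax - qmin), min=1e-8)
    zero_point = (qmin - x_min / scale).round().clamp(qmin, qmax)
    q = (x / scale + zero_point).round().clamp(qmin, qmax).to(torch.int8)
    return q, scale, zero_point


x = torch.randn(4, 16)
q, scale, zp = dynamic_quantize_int8(x)
x_hat = (q.float() - zp) * scale  # approximate reconstruction of x
```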
@@ -230,7 +232,7 @@ adb push cmake-out-android/examples/models/llama2/llama_main /data/local/tmp/lla

 **2.3 Run model**
 ```
-adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.bin> --prompt "Once upon a time" --seq_len 120
+adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.bin> --prompt \"Once upon a time\" --seq_len 120"
 ```
 ## Step 6: Build Mobile apps

@@ -263,12 +265,13 @@ This example tries to reuse the Python code, with minimal modifications to make
 3. No dependencies on fairscale. The ColumnParallelLinear, ParallelEmbedding and training are not needed and supported in ExecuTorch.


-# Clean
-To clean your build:
+# Common Issues and Mitigations:
+- To clean your build:
 ```
 git clean -xfd
 pip uninstall executorch
 ./install_requirements.sh <options>

 rm -rf cmake-out
 ```
+- If you encounter `pthread` related issues during link time, add `pthread` in `target_link_libraries` in `CMakeLists.txt`

examples/models/llama2/builder.py
Lines changed: 4 additions & 0 deletions

@@ -62,6 +62,7 @@ def to_torch_dtype(self) -> torch.dtype:

 def load_llama_model(
     *,
+    modelname: str = "llama2",
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,

@@ -114,6 +115,7 @@ def load_llama_model(

     return LlamaEdgeManager(
         model=model,
+        modelname=modelname,
         weight_type=weight_type,
         dtype=dtype,
         use_kv_cache=use_kv_cache,

@@ -131,6 +133,7 @@ class LlamaEdgeManager:
     def __init__(
         self,
         model,
+        modelname,
         weight_type,
         dtype,
         use_kv_cache,

@@ -139,6 +142,7 @@ def __init__(
         verbose: bool = False,
     ):
         self.model = model
+        self.modelname = modelname
         self.weight_type = weight_type
         self.dtype = dtype
         self.example_inputs = example_inputs

examples/models/llama2/export_llama_lib.py
Lines changed: 5 additions & 3 deletions

@@ -37,7 +37,7 @@
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers

-from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
+from .quantize import EmbeddingQuantHandler, WeightOnlyInt8QuantHandler


 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)

@@ -485,7 +485,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     )
     params_path = canonical_path(args.params)
     output_dir_path = canonical_path(args.output_dir, dir=True)
-    modelname = "llama2"
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

     # dtype override

@@ -539,7 +538,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
             group_size = int(group_size)
             bitwidth = int(bitwidth)
             transforms.append(
-                lambda model: EmbeddingOnlyInt8QuantHandler(
+                lambda model: EmbeddingQuantHandler(
                     model, bitwidth=bitwidth, group_size=group_size
                 ).quantized_model()
             )

@@ -552,6 +551,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:

     return (
         load_llama_model(
+            modelname=modelname,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
             params_path=params_path,

@@ -599,6 +599,8 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         modelname, args
     ).export_to_edge(quantizers)

+    modelname = builder_exported_to_edge.modelname
+
     # to_backend
     partitioners = []
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:

examples/models/llama2/quantize.py
Lines changed: 92 additions & 38 deletions

@@ -124,6 +124,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points


+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod

@@ -134,8 +138,15 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass

+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod

-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###


 def replace_linear_weight_only_int8_per_channel(module, node_type):
@@ -153,16 +164,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
                setattr(
                    module,
                    name,
-                    WeightOnlyInt8Linear(child.in_features, child.out_features),
+                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                )
        else:
            replace_linear_weight_only_int8_per_channel(child, node_type)


-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,

@@ -202,7 +214,7 @@ def create_quantized_state_dict(self) -> Dict:
                    )
                ):
                    print(
-                        f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                    )

                    # print(f"initial weight shape {mod.weight.shape}")

@@ -219,7 +231,7 @@ def create_quantized_state_dict(self) -> Dict:
                    )

                    cur_state_dict[f"{fqn}.weight"] = weight
-                    # squeeze makes groupsize=rowsize unidimensional
+                    # squeeze makes group_size=rowsize unidimensional
                    cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

@@ -243,10 +255,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):

     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()

@@ -262,11 +274,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######


 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")

@@ -277,25 +290,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                module,
                name,
                QuantizedGroupEmbedding(
+                    device=device,
                    vocab_size=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
                    group_size=group_size,
+                    packed=packed,
                ),
            )
        else:
            replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
            )


-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingQuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("pack only works with bitsize 4")

     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()

         if self.bitwidth == 4:

@@ -308,18 +337,14 @@ def create_quantized_state_dict(self) -> Dict:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                # print("****")
                # print(f"Embedding identified: {fqn, mod}")
                # print(f"weights size: {mod.weight.size()}")
                # print(f"quantize {fqn}...")

                print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                )
                weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(),
@@ -330,21 +355,36 @@ def create_quantized_state_dict(self) -> Dict:
                    scales_dtype=mod.weight.dtype,
                )

+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                # Update state dict
                cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
@@ -353,39 +393,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )

     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.quantized_decomposed.embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-        #
-        # r = result_weights.to(dtype=result_scales.dtype) * result_scales
-        # return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
