
Commit e46b6b6 (merge commit, 2 parents: 9f05c5c + c44d8ef)
Update on "qnn end to end flow"
Patch a few changes, including:
- support bool tensor type
- support fp16 and fix the 8a8w quantization
- add two non-supported ops (slice_scatter and index_put) in common_defs.py

The stories model now works end to end.

AOT:

fp16:
```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```

quantize:
```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json
```

Runtime:
```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```

Output:
```
Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet.
```

Note: the stories model is too small and is therefore sensitive to quantization.

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)

[ghstack-poisoned]
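On the common_defs.py item above: the change marks slice_scatter and index_put as unsupported by the QNN partitioner. Below is a minimal sketch of what such a skip list could look like; the list name `not_supported_operator` and the surrounding layout are assumptions for illustration, not a copy of the actual source.

```python
# Hypothetical sketch of the common_defs.py change; only the two newly
# listed ops (slice_scatter, index_put) come from the commit message.
from executorch.exir.dialects._ops import ops as exir_ops

not_supported_operator = [
    # ... existing unsupported edge ops ...
    exir_ops.edge.aten.slice_scatter.default,  # newly marked as unsupported
    exir_ops.edge.aten.index_put.default,      # newly marked as unsupported
]
```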

File tree: 4 files changed, +18 −8 lines

examples/models/llama2/README.md

Lines changed: 4 additions & 3 deletions
@@ -24,11 +24,12 @@ For Llama3, we can use the same process. Note that it's only supported in the Ex
 ## Quantization:
 We employed 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. For more information refer to this [page](https://github.com/pytorch-labs/ao/).

-We evaluated UncycloText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different groupsizes.
+We evaluated UncycloText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different groupsizes, with max_seq_len 2048, and 1000 samples.

-|Llama 2 | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256)
+|Model | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256)
 |--------|-----------------| ---------------------- | ---------------
-|Uncyclotext Perplexity | 9.16 | 10.2 | 10.7
+|Llama 2 7B | 9.2 | 10.2 | 10.7
+|Llama 3 8B | 7.9 | 9.4 | 9.7

 Note that groupsize less than 128 was not enabled, since such model were still too large. This is because our current efforts have focused on enabling FP32 and support for FP16 is under way. What this implies for model size is that 1) embedding table is in FP32 and 2) quantized weights scales are FP32.
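For readers unfamiliar with the scheme described in this README hunk, the sketch below shows the basic arithmetic of groupwise 4-bit weight quantization (one FP32 scale per group of, e.g., 128 or 256 elements along each output channel). It is a purely illustrative, symmetric variant, not the torchao/ExecuTorch implementation.

```python
import torch

def quantize_weights_groupwise(w: torch.Tensor, group_size: int = 128):
    """Illustrative symmetric 4-bit groupwise quantization of a 2D weight.

    Each row is split into groups of `group_size` elements; one FP32 scale
    is computed per group so values map into the signed 4-bit range [-8, 7].
    This mirrors the idea in the README, not the exact library code.
    """
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    groups = w.reshape(out_features, in_features // group_size, group_size)

    # Per-group scale from the max absolute value (symmetric quantization).
    scales = groups.abs().amax(dim=-1, keepdim=True) / 7.0
    scales = scales.clamp(min=1e-8)  # avoid division by zero

    q = torch.clamp(torch.round(groups / scales), -8, 7).to(torch.int8)
    return q.reshape(out_features, in_features), scales.squeeze(-1)

# Example: dequantize and check the reconstruction error.
w = torch.randn(32, 256)
q, scales = quantize_weights_groupwise(w, group_size=128)
w_hat = (q.float().reshape(32, 2, 128) * scales.unsqueeze(-1)).reshape(32, 256)
print((w - w_hat).abs().max())
```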

examples/models/llama2/eval_llama_lib.py

Lines changed: 3 additions & 4 deletions
@@ -42,12 +42,11 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
-        super().__init__()
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
-        self._device = (
-            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        )
+        self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length

     @property
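The change above picks the device string once and reuses it both for the base-class constructor and for the torch.device handle. A standalone sketch of that pattern is shown below; `EvalBase` and `EvalWrapperSketch` are hypothetical stand-ins, not the classes in eval_llama_lib.py.

```python
import torch

class EvalBase:
    """Hypothetical stand-in for an eval-harness base class that accepts a device."""
    def __init__(self, device: str = "cpu"):
        self.device_name = device

class EvalWrapperSketch(EvalBase):
    def __init__(self, model, max_seq_length=None):
        # Choose the device string once, then reuse it for the base class
        # and for the torch.device handle (the pattern used in the diff).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(device=device)
        self._model = model
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length

wrapper = EvalWrapperSketch(model=torch.nn.Linear(4, 4))
print(wrapper._device, wrapper._max_seq_length)
```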

examples/models/llama2/export_llama_lib.py

Lines changed: 9 additions & 1 deletion
@@ -685,15 +685,20 @@ def _export_llama(modelname, args) -> str: # noqa: C901
         # more custom quantization are supported including 16a4w etc. default to 8bit quantized
         custom_annotations = ()
         if quant_config == "8a8w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             quant_dtype = QuantDtype.use_8a8w
             pass
         elif quant_config == "16a16w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             quant_dtype = QuantDtype.use_16a16w
             qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
         elif quant_config == "16a4w":
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             quant_dtype = QuantDtype.use_16a4w
             qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
             qnn_quantizer.set_per_channel_weight_dtype(
                 weight_dtype_for_16bit_act="int4"
@@ -822,16 +827,18 @@ def _export_llama(modelname, args) -> str: # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )

-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
         use_fp16 = True
         skip_node_op_set = {}
         if args.pt2e_quantize:
             use_fp16 = False
             # TODO: fix the lowering error without skipping nodes
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             if quant_dtype == QuantDtype.use_8a8w:
                 raise NotImplementedError("8a8w for llama is still under development")
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             elif quant_dtype == QuantDtype.use_16a16w:
                 raise NotImplementedError("16a16w for llama is still under development")
+            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
             elif quant_dtype == QuantDtype.use_16a4w:
                 raise NotImplementedError("16a4w for llama is still under development")
         partitioners.append(
@@ -841,6 +848,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
             generate_qnn_executorch_compiler_spec(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                 soc_model=QcomChipset.SM8650, # default to SM8650
+                # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                 backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
                 debug=False,
                 saver=False,
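For context, the quant_config strings handled in the first hunk follow the usual `<activation bits>a<weight bits>w` naming. The small decoding below is purely illustrative and is not part of the ExecuTorch code; it only spells out what each recipe string means.

```python
# Illustrative decoding of the quant_config strings in the hunk above.
QNN_QUANT_RECIPES = {
    "8a8w": {"activation_bits": 8, "weight_bits": 8},
    "16a16w": {"activation_bits": 16, "weight_bits": 16},
    "16a4w": {"activation_bits": 16, "weight_bits": 4},  # int4 per-channel weights
}

def describe(quant_config: str) -> str:
    recipe = QNN_QUANT_RECIPES[quant_config]
    return (
        f"{quant_config}: {recipe['activation_bits']}-bit activations, "
        f"{recipe['weight_bits']}-bit weights"
    )

for name in QNN_QUANT_RECIPES:
    print(describe(name))
```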

examples/models/llama3/README.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# Summary
+For Llama3, use the same example code, minus tokenizer, as Llama2. Please see the ../llama2/README.md for details.
