Skip to content

Commit 96586a7

Browse files
committed
Update base for Update on "qnn end to end flow"
Patch a few changes including: - support bool tensor type - support fp16 and fix the 8w8a quantization. - add two non-supported ops (slice_scatter and index_put) in common_defs.py stories model working end to end: AOT: fp16: ``` python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json ``` quantize: ``` python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json ``` Runtime: ``` /llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once" ``` Output: ``` Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet. ``` Stories model is too small and sensitive to qunatization. Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/) [ghstack-poisoned]
2 parents c44d8ef + 2c467dd commit 96586a7

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
614614
bitwidth = int(bitwidth)
615615
transforms.append(
616616
lambda model: EmbeddingQuantHandler(
617-
model, bitwidth=bitwidth, group_size=group_size
617+
model,
618+
bitwidth=bitwidth,
619+
group_size=group_size,
620+
packed=(bitwidth == 4),
618621
).quantized_model()
619622
)
620623

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta(
189189

190190
quantized_decomposed_lib.define(
191191
"embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
192-
"int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
192+
"int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
193193
)
194194

195195

0 commit comments

Comments
 (0)