Commit 5d396b3

qnn end to end flow

Pull Request resolved: #3038

Patch a few changes, including:

- support bool tensor type
- support fp16 and fix the 8a8w quantization
- add two unsupported ops (slice_scatter and index_put) in common_defs.py

stories model working end to end:

AOT:

fp16:

```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```

quantize:

```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize qnn_8a8w -c stories110M.pt -p params.json
```

Runtime:

```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```

Output:

```
Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet.
```

The stories model is too small and is therefore sensitive to quantization.

ghstack-source-id: 222473434
exported-using-ghexport

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)

1 parent 01446ac

File tree: 3 files changed, +64 −10 lines

- backends/qualcomm/builders/node_visitor.py
- backends/qualcomm/partition/common_defs.py
- examples/models/llama2/export_llama_lib.py

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@
     QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
+    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
     torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
```
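For context on how this map is consumed: the node visitor looks up each torch dtype when it materializes QNN tensors, so any dtype missing from the map (as `torch.bool` was) cannot lower. A minimal sketch of that lookup; the helper `torch_dtype_to_qnn` and the string stand-ins for `PyQnnWrapper.Qnn_DataType_t` values are illustrative, not the backend's actual code:

```python
import torch

# String stand-ins for PyQnnWrapper.Qnn_DataType_t enum values (illustrative).
QNN_TENSOR_TYPE_MAP = {
    torch.bool: "QNN_DATATYPE_BOOL_8",  # the mapping this commit adds
    torch.float32: "QNN_DATATYPE_FLOAT_32",
    torch.int8: "QNN_DATATYPE_INT_8",
    torch.int16: "QNN_DATATYPE_INT_16",
}


def torch_dtype_to_qnn(dtype: torch.dtype) -> str:
    # Fail loudly for dtypes the backend cannot represent.
    if dtype not in QNN_TENSOR_TYPE_MAP:
        raise NotImplementedError(f"no QNN data type for {dtype}")
    return QNN_TENSOR_TYPE_MAP[dtype]


mask = torch.ones(4, dtype=torch.bool)  # e.g. a boolean attention mask
print(torch_dtype_to_qnn(mask.dtype))   # QNN_DATATYPE_BOOL_8
```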

backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,6 +13,8 @@
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.slice_scatter.default,
+    exir_ops.edge.aten.index_put.default,
 ]
 
 allow_list_operator = [
```
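For orientation, a rough sketch of how a deny-list like `not_supported_operator` typically feeds the partitioner's support check, so the two newly listed ops stay on CPU instead of lowering to QNN. The string targets and the `is_node_supported` helper are simplifications for illustration; the real list holds `exir_ops.edge.*` operator overloads:

```python
# Simplified: the real entries are exir_ops.edge.aten.* operator overloads.
not_supported_operator = [
    "aten.clone.default",
    "aten.index.Tensor",
    "aten.full.default",
    "aten.slice_scatter.default",  # added by this patch
    "aten.index_put.default",      # added by this patch
]


def is_node_supported(target: str) -> bool:
    # Deny-listed nodes are excluded from QNN partitions and run on CPU.
    return target not in not_supported_operator


assert not is_node_supported("aten.index_put.default")
assert is_node_supported("aten.add.Tensor")
```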

examples/models/llama2/export_llama_lib.py

Lines changed: 61 additions & 10 deletions

```diff
@@ -19,6 +19,7 @@
 
 import pkg_resources
 import torch
+import torch.nn.functional as F
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
@@ -34,7 +35,6 @@
 from executorch.sdk.etrecord import generate_etrecord
 from executorch.util.activation_memory_profiler import generate_memory_trace
 from sentencepiece import SentencePieceProcessor
-from torch.nn import functional as F
 
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
```
```diff
@@ -337,6 +337,13 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--pt2e_quantize",
         default=None,
+        choices=[
+            "xnnpack_dynamic",
+            "xnnpack_dynamic_qc4",
+            "qnn_8a8w",
+            "qnn_16a16w",
+            "qnn_16a4w",
+        ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
     parser.add_argument(
```
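Making the quantization mode an explicit `choices` list means argparse rejects unknown modes up front. A self-contained sketch of that behavior, including the `<backend>_<quant_config>` split the export flow performs on `qnn_*` values in a later hunk (parser shortened for illustration):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--pt2e_quantize",
    default=None,
    choices=["xnnpack_dynamic", "xnnpack_dynamic_qc4", "qnn_8a8w", "qnn_16a16w", "qnn_16a4w"],
)

args = parser.parse_args(["--pt2e_quantize", "qnn_8a8w"])
backend, quant_config = args.pt2e_quantize.split("_")  # only done for qnn_* values
print(backend, quant_config)  # qnn 8a8w

# An unknown mode now fails at parse time instead of deep in the export:
# parser.parse_args(["--pt2e_quantize", "qnn_4a4w"])  -> SystemExit with a choices error
```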
```diff
@@ -607,6 +614,8 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
 
+    if args.qnn and args.use_kv_cache:
+        transforms.append(replace_sdpa_with_simple_sdpa)
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
```
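`replace_sdpa_with_simple_sdpa` is assumed here to swap the fused SDPA custom op for a plain decomposition the QNN partitioner can consume op by op. A generic sketch of such a decomposition, checked against `F.scaled_dot_product_attention`; this is illustrative math, not the exact ExecuTorch transform:

```python
import math

import torch
import torch.nn.functional as F


def simple_sdpa(q, k, v, mask=None):
    # softmax(q @ k^T / sqrt(head_dim)) @ v, expressed with basic ops.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask  # additive mask (-inf on masked positions)
    return F.softmax(scores, dim=-1) @ v


q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
assert torch.allclose(
    simple_sdpa(q, k, v),
    F.scaled_dot_product_attention(q, k, v),
    atol=1e-5,
)
```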
```diff
@@ -629,13 +638,16 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     # export_to_edge
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
-    if args.qnn:
-        assert (
-            args.quantization_mode is None
-        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
+    quant_dtype = None
+    if args.qnn and args.pt2e_quantize:
         try:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
-            from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
+            from executorch.backends.qualcomm.quantizer.quantizer import (
+                get_16a4w_qnn_ptq_config,
+                get_default_16bit_qnn_ptq_config,
+                QnnQuantizer,
+                QuantDtype,
+            )
 
             # reset quantizers and pt2e_quant_params from xnnpack backend
             pt2e_quant_params = None
@@ -645,10 +657,36 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )
 
+        backend, quant_config = args.pt2e_quantize.split("_")
+        assert (
+            backend == "qnn"
+        ), f"The quantization config is for backend {backend} instead of qnn."
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
         qnn_quantizer = QnnQuantizer()
         # more custom quantization are supported including 16a4w etc. default to 8bit quantized
         custom_annotations = ()
+        if quant_config == "8a8w":
+            quant_dtype = QuantDtype.use_8a8w
+            pass
+        elif quant_config == "16a16w":
+            quant_dtype = QuantDtype.use_16a16w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+        elif quant_config == "16a4w":
+            quant_dtype = QuantDtype.use_16a4w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            qnn_quantizer.set_per_channel_weight_dtype(
+                weight_dtype_for_16bit_act="int4"
+            )
+        else:
+            raise AssertionError(
+                f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
+            )
+
+        assert (
+            args.quantization_mode is None
+        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
         qnn_quantizer.add_custom_quant_annotations(custom_annotations)
         quantizers.append(qnn_quantizer)
 
```
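For reference, a quantizer configured this way plugs into the standard pt2e calibrate-and-convert flow before export to edge. A hedged sketch with a toy module; the exact capture API varies across torch versions, and the `QnnQuantizer` import (taken from this diff) requires the Qualcomm backend to be built:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Import path taken from this diff; needs the Qualcomm backend installed.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()  # defaults correspond to the 8a8w branch above

model = torch.nn.Linear(16, 16).eval()
example_inputs = (torch.randn(1, 16),)

# Older torch versions used torch._export.capture_pre_autograd_graph here.
graph = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(graph, quantizer)
prepared(*example_inputs)           # calibration pass to collect ranges
quantized = convert_pt2e(prepared)  # folds observers into quantize/dequantize ops
```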

```diff
@@ -763,24 +801,37 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         )
 
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        use_fp16 = True
+        skip_node_op_set = {}
+        if args.pt2e_quantize:
+            use_fp16 = False
+            # TODO: fix the lowering error without skipping nodes
+            if quant_dtype == QuantDtype.use_8a8w:
+                skip_node_op_set = {
+                    "aten.unsqueeze_copy.default",
+                    "aten.permute_copy.default",
+                }
+            elif quant_dtype == QuantDtype.use_16a16w:
+                raise NotImplementedError("16a16w for llama is still under development")
+            elif quant_dtype == QuantDtype.use_16a4w:
+                raise NotImplementedError("16a4w for llama is still under development")
         partitioners.append(
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
             QnnPartitioner(
                 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
                 generate_qnn_executorch_compiler_spec(
                     # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                     soc_model=QcomChipset.SM8650,  # default to SM8650
-                    backend_options=backend_options,
+                    backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
                     debug=False,
                     saver=False,
                 ),
                 skip_node_id_set={},
-                skip_node_op_set={},
+                skip_node_op_set=skip_node_op_set,
             )
         )
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        _transform(builder_exported_to_edge.export_program())
+        _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.generate_etrecord:
             if not builder_exported_to_edge.edge_manager:
```
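Two behaviors in this last hunk are worth noting: `use_fp16` now defaults to True, so a plain `--qnn` export without `--pt2e_quantize` lowers the model as fp16 on HTP, matching the fp16 command in the commit message; and the 8a8w skip list is an acknowledged stopgap for lowering errors, per the TODO. The `_transform` call is also fixed to fetch the exported program through `edge_manager`.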
