
Commit db73e68

qnn end to end flow for stories model
Pull Request resolved: #3038

This patch includes a few changes:
- support the bool tensor type
- support fp16 and fix the 8a8w quantization
- add two unsupported ops (slice_scatter and index_put) to common_defs.py

The stories model now works end to end.

AOT:

fp16:
```
python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json
```

quantize:
```
python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize qnn_8a8w -c stories110M.pt -p params.json
```

Runtime:
```
/llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once"
```

Output:
```
Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red apple hanging from a tree. She wanted to eat it, but it was too high up..
```

The stories model is very small and therefore sensitive to quantization.

ghstack-source-id: 222613601
@exported-using-ghexport

Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/)
1 parent f83de2e commit db73e68

File tree: 3 files changed (+65 −9 lines)

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@
     QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
 }
 QNN_TENSOR_TYPE_MAP = {
+    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
     torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
     torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
     torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
```
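For context, this map is what the QNN builders consult when translating a torch dtype into a QNN data type, so adding `torch.bool` is what lets bool tensors be serialized for the backend. Below is a minimal, self-contained sketch of that kind of lookup; the string values stand in for the `PyQnnWrapper.Qnn_DataType_t` enum members, and `to_qnn_dtype` is an assumed helper name for illustration, not code from this patch.

```python
# Minimal sketch of a dtype lookup like QNN_TENSOR_TYPE_MAP; enum values are
# string stand-ins and to_qnn_dtype is a hypothetical helper, not patch code.
import torch

EXAMPLE_TENSOR_TYPE_MAP = {
    torch.bool: "QNN_DATATYPE_BOOL_8",      # the mapping this diff adds
    torch.float32: "QNN_DATATYPE_FLOAT_32",
    torch.int8: "QNN_DATATYPE_INT_8",
    torch.int16: "QNN_DATATYPE_INT_16",
}


def to_qnn_dtype(dtype: torch.dtype) -> str:
    """Resolve a torch dtype to its QNN data type, failing loudly if unmapped."""
    try:
        return EXAMPLE_TENSOR_TYPE_MAP[dtype]
    except KeyError:
        raise NotImplementedError(f"No QNN data type registered for {dtype}")


# With the new entry, bool tensors (e.g. masks) now resolve instead of raising:
print(to_qnn_dtype(torch.bool))  # QNN_DATATYPE_BOOL_8
```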

backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,6 +13,8 @@
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.slice_scatter.default,
+    exir_ops.edge.aten.index_put.default,
 ]

 allow_list_operator = [
```
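These two ops are appended to the list of operators the QNN partitioner treats as not supported, so the corresponding nodes stay on CPU rather than being delegated. The snippet below is a rough, hypothetical illustration of how a deny-list check of this shape behaves; the abbreviated op strings and `is_node_supported` are assumptions, not the actual QnnPartitioner implementation.

```python
# Hypothetical illustration of a deny-list gate like the one in common_defs.py;
# is_node_supported is an assumed helper, not the real partitioner API.
not_supported_operator = [
    "aten.clone.default",
    "aten.index.Tensor",
    "aten.full.default",
    "aten.slice_scatter.default",  # newly excluded by this patch
    "aten.index_put.default",      # newly excluded by this patch
]


def is_node_supported(op_name: str) -> bool:
    """Ops on the deny list are kept on CPU rather than delegated to QNN."""
    return op_name not in not_supported_operator


assert not is_node_supported("aten.index_put.default")
assert is_node_supported("aten.add.Tensor")
```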

examples/models/llama2/export_llama_lib.py

Lines changed: 62 additions & 9 deletions
```diff
@@ -19,6 +19,7 @@

 import pkg_resources
 import torch
+import torch.nn.functional as F
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
@@ -354,6 +355,13 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--pt2e_quantize",
         default=None,
+        choices=[
+            "xnnpack_dynamic",
+            "xnnpack_dynamic_qc4",
+            "qnn_8a8w",
+            "qnn_16a16w",
+            "qnn_16a4w",
+        ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
     parser.add_argument(
@@ -624,6 +632,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)

+    if args.qnn and args.use_kv_cache:
+        transforms.append(replace_sdpa_with_simple_sdpa)
+        transforms.append(replace_causal_mask)
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
@@ -646,13 +657,16 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     # export_to_edge
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
-    if args.qnn:
-        assert (
-            args.quantization_mode is None
-        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
+    quant_dtype = None
+    if args.qnn and args.pt2e_quantize:
         try:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
-            from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
+            from executorch.backends.qualcomm.quantizer.quantizer import (
+                get_16a4w_qnn_ptq_config,
+                get_default_16bit_qnn_ptq_config,
+                QnnQuantizer,
+                QuantDtype,
+            )

             # reset quantizers and pt2e_quant_params from xnnpack backend
             pt2e_quant_params = None
@@ -662,10 +676,36 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
                 "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
             )

+        backend, quant_config = args.pt2e_quantize.split("_")
+        assert (
+            backend == "qnn"
+        ), f"The quantization config is for backend {backend} instead of qnn."
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
         qnn_quantizer = QnnQuantizer()
         # more custom quantization are supported including 16a4w etc. default to 8bit quantized
         custom_annotations = ()
+        if quant_config == "8a8w":
+            quant_dtype = QuantDtype.use_8a8w
+            pass
+        elif quant_config == "16a16w":
+            quant_dtype = QuantDtype.use_16a16w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+        elif quant_config == "16a4w":
+            quant_dtype = QuantDtype.use_16a4w
+            qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
+            qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            qnn_quantizer.set_per_channel_weight_dtype(
+                weight_dtype_for_16bit_act="int4"
+            )
+        else:
+            raise AssertionError(
+                f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
+            )
+
+        assert (
+            args.quantization_mode is None
+        ), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
         qnn_quantizer.add_custom_quant_annotations(custom_annotations)
         quantizers.append(qnn_quantizer)

@@ -780,24 +820,37 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
            )

        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        use_fp16 = True
+        skip_node_op_set = {}
+        if args.pt2e_quantize:
+            use_fp16 = False
+            # TODO: fix the lowering error without skipping nodes
+            if quant_dtype == QuantDtype.use_8a8w:
+                skip_node_op_set = {
+                    "aten.unsqueeze_copy.default",
+                    "aten.permute_copy.default",
+                }
+            elif quant_dtype == QuantDtype.use_16a16w:
+                raise NotImplementedError("16a16w for llama is still under development")
+            elif quant_dtype == QuantDtype.use_16a4w:
+                raise NotImplementedError("16a4w for llama is still under development")
        partitioners.append(
            # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
            QnnPartitioner(
                # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                generate_qnn_executorch_compiler_spec(
                    # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
                    soc_model=QcomChipset.SM8650,  # default to SM8650
-                    backend_options=backend_options,
+                    backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
                    debug=False,
                    saver=False,
                ),
                skip_node_id_set={},
-                skip_node_op_set={},
+                skip_node_op_set=skip_node_op_set,
            )
        )
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
-        _transform(builder_exported_to_edge.export_program())
+        _transform(builder_exported_to_edge.edge_manager.exported_program())

        if args.generate_etrecord:
            if not builder_exported_to_edge.edge_manager:
```
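To summarize the control flow this diff adds to `_export_llama`: the `--pt2e_quantize` value is split into a backend prefix and a quantization config, the config selects a `QuantDtype` and the corresponding quantizer settings, and a quantized export disables fp16 in the HTP compiler spec while skipping a couple of ops that currently fail to lower. The condensed sketch below mirrors that dispatch with stub types; it is an illustration under those assumptions, not the real ExecuTorch/QNN API.

```python
# Condensed, hypothetical sketch of the dispatch added in _export_llama.
# QuantDtype here is a stub enum; the real helpers live in the QNN backend.
from enum import Enum, auto
from typing import Optional, Set, Tuple


class QuantDtype(Enum):
    use_8a8w = auto()
    use_16a16w = auto()
    use_16a4w = auto()


def resolve_quant_dtype(pt2e_quantize: Optional[str]) -> Optional[QuantDtype]:
    """Split e.g. 'qnn_8a8w' into backend and config, mirroring the diff."""
    if pt2e_quantize is None:
        return None
    backend, quant_config = pt2e_quantize.split("_")
    assert backend == "qnn", f"The quantization config is for backend {backend} instead of qnn."
    if quant_config == "8a8w":
        return QuantDtype.use_8a8w
    if quant_config == "16a16w":
        return QuantDtype.use_16a16w
    if quant_config == "16a4w":
        return QuantDtype.use_16a4w
    raise AssertionError(f"No support for quant type {quant_config}.")


def htp_spec_flags(quant_dtype: Optional[QuantDtype]) -> Tuple[bool, Set[str]]:
    """fp16 only for unquantized exports; 8a8w skips two ops as a lowering workaround."""
    use_fp16 = quant_dtype is None
    skip_node_op_set: Set[str] = set()
    if quant_dtype == QuantDtype.use_8a8w:
        skip_node_op_set = {"aten.unsqueeze_copy.default", "aten.permute_copy.default"}
    elif quant_dtype in (QuantDtype.use_16a16w, QuantDtype.use_16a4w):
        raise NotImplementedError(f"{quant_dtype} for llama is still under development")
    return use_fp16, skip_node_op_set


print(htp_spec_flags(resolve_quant_dtype(None)))        # fp16 path: (True, set())
print(htp_spec_flags(resolve_quant_dtype("qnn_8a8w")))  # quantized path: (False, {...})
```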
