
Commit 66bfd75

Qualcomm AI Engine Direct - Unify Llama2&Llama3 and Small Accuracy Improvement. (#7618)
Qualcomm AI Engine Direct - Unify Llama2 and Llama3
1 parent 108ec68 commit 66bfd75

File tree

24 files changed: +332 additions, -2022 deletions


backends/qualcomm/_passes/insert_requantize.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -89,15 +89,9 @@ def _single_output_annotation(
         requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
         # {quant_attr: user_node_name_list}
         group_quant_attr_dict = self._invert_dict(requantize_dict)
-        # TODO: If users of the node contain output node,
-        # we replace the node with to_copy op. However, it would
-        # be problem when the node has multiple to_copy ops
-        add_output = len(group_quant_attr_dict) == 1

         for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
             user_nodes_copy = user_nodes.copy()
-            if add_output:
-                user_nodes_copy.append("output")
             self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

     def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
```

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 68 additions & 6 deletions
```diff
@@ -14,17 +14,80 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import MinMaxObserver
+from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
+    QuantizationSpec,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


-def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, traverse_input1=True
-) -> None:
+def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
+    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        weight = node.args[1]
+        input_qspec_map[weight] = quantization_config.weight
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[-1][0]
+                if full_qualified_name == "output.conv":
+                    annotate_conv2d(
+                        node, quantization_config=quantization_config_16a8w_per_channel
+                    )
+
+
+def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            for index, prefill_output in enumerate(node.args[0]):
+                kv_quant_attr = kv_quant_attrs[index]
+                fixed_observer = FixedQParamsObserver.with_args(
+                    scale=kv_quant_attr[0],
+                    zero_point=kv_quant_attr[1],
+                    quant_min=kv_quant_attr[2],
+                    quant_max=kv_quant_attr[3],
+                    dtype=kv_quant_attr[4],
+                    qscheme=torch.torch.per_tensor_affine,
+                )
+
+                fixed_output_spec = QuantizationSpec(
+                    quant_min=kv_quant_attr[2],
+                    quant_max=kv_quant_attr[3],
+                    dtype=kv_quant_attr[4],
+                    ch_axis=0,
+                    observer_or_fake_quant_ctr=fixed_observer,
+                )
+
+                input_qspec_map = {}
+                for input in prefill_output.args:
+                    if isinstance(input, Node):
+                        input_qspec_map[input] = fixed_output_spec
+
+                prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+                    input_qspec_map=input_qspec_map,
+                    output_qspec=fixed_output_spec,
+                    _annotated=True,
+                )
+
+
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
@@ -142,8 +205,7 @@ def annotate_matmul_input1(node: Node):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            if traverse_input1:
-                annotate_matmul_input1(node.args[1])
+            annotate_matmul_input1(node.args[1])


 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
```
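The two new helpers above are meant to be handed to the quantizer as custom annotation callbacks. Below is a hypothetical wiring sketch, not part of this commit: the `QnnQuantizer.add_custom_quant_annotations` hook, the module paths, and the sample `kv_quant_attrs` values are assumptions for illustration only.

```python
# Hypothetical usage sketch (not part of this diff). kv_quant_attrs maps each
# prefill output index to (scale, zero_point, quant_min, quant_max, dtype),
# matching how annotate_prefill_kv_output unpacks each entry; the hook name,
# module paths, and values below are assumptions for illustration.
from functools import partial

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_prefill_kv_output,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

kv_quant_attrs = {0: (0.015, 0, 0, 65535, torch.uint16)}

quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(
    (
        annotate_linear_16a8w_in_affine_layer,
        partial(annotate_prefill_kv_output, kv_quant_attrs=kv_quant_attrs),
    )
)
```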

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -3280,7 +3280,7 @@ def test_stories_single_llama(self):

         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
             "--artifact",
             self.artifact_dir,
             "--build_folder",
@@ -3307,6 +3307,8 @@ def test_stories_single_llama(self):
             "16a4w",
             "--temperature",
             "0",
+            "--llama_model",
+            "stories110m",
         ]
         if self.host:
             cmds.extend(["--host", self.host])
```

examples/qualcomm/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
```diff
@@ -84,11 +84,8 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
 # build qnn_executor_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)

-# build qnn_llama_runner for llama2
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)
-
-# build qnn_llama_runner for llama3.2
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)
+# build qnn_llama_runner for llama
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)

 # build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
```
examples/qualcomm/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -4,10 +4,10 @@ This directory contains examples for some AI models.

 We have seperated the example scripts into the following subfolders, please refer to [README.md](../../backends/qualcomm/README.md) for the example scripts' directory structure:

-1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
+1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama](./oss_scripts/llama/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.

 2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner.
-For example, [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
+For example, [llama](./oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.

 3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Different from oss_scripts & scripts, which the example scripts are converting a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionaly, users can find customized example runners specific to the QAIHub models for execution. For example [qaihub_llama2_7b](./qaihub_scripts/llama2/qaihub_llama2_7b.py) is a script converting context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](./qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner to execute llama2 .pte files. Please be aware that context-binaries downloaded from QAIHub are tied to a specific QNN SDK version.
 Before executing the scripts and runner, please ensure that you are using the QNN SDK version that is matching the context binary. Tutorial below will also cover how to check the QNN Version for a context binary.
```

examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt renamed to examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 14 additions & 17 deletions
```diff
@@ -18,38 +18,35 @@ target_link_libraries(
 )
 target_link_options_shared_lib(custom_ops)

-# preprocess qnn runner src files for llama3.2
-set(_llama3_2_runner__srcs ${_llama_runner__srcs})
-list(TRANSFORM _llama3_2_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
-list(FILTER _llama3_2_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
+# preprocess qnn runner src files for llama
+set(_llama_runner__srcs ${_llama_runner__srcs})
+list(TRANSFORM _llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
+list(FILTER _llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
 list(
   PREPEND
-  _llama3_2_runner__srcs
-  ${CMAKE_CURRENT_LIST_DIR}/qnn_llama3_2_runner.cpp
+  _llama_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
 )

-list(
-  APPEND _llama3_2_runner__srcs
-  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
-)
 list(
   APPEND
-  _llama3_2_runner__srcs
+  _llama_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp
 )

-# build qnn llama3.2 1b runner
-add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
+# build qnn llama runner
+add_executable(qnn_llama_runner ${_llama_runner__srcs})
 target_include_directories(
-  qnn_llama3_2_runner PUBLIC ${_common_include_directories}
+  qnn_llama_runner PUBLIC ${_common_include_directories}
 )

 target_link_libraries(
-  qnn_llama3_2_runner
+  qnn_llama_runner
   qnn_executorch_backend
   executorch_core
   extension_data_loader
@@ -60,8 +57,8 @@ target_link_libraries(
   custom_ops
 )
 target_compile_options(
-  qnn_llama3_2_runner PUBLIC ${_common_compile_options}
+  qnn_llama_runner PUBLIC ${_common_compile_options}
 )
 set_target_properties(
-  qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+  qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
 )
```
Lines changed: 70 additions & 0 deletions (new file)

# Summary

## Overview
This file provides instructions for running LLAMA models with different parameters via the Qualcomm HTP backend. We currently support the following models:
1. LLAMA2 Stories 110M
2. LLAMA3.2 1B
3. LLAMA3.2 3B (WIP)

We offer the following modes to execute the model:

Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for generating the initial sequence of tokens (usually the user's prompt).

KV Cache Mode: In KV cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.

Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
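To make the hybrid flow concrete, here is a minimal decoding-loop sketch. The `prefill_graph` and `kv_graph` callables and their return values are placeholders for illustration only, not the actual runner API.

```python
# Minimal sketch of hybrid-mode decoding under assumed interfaces: a prefill
# graph consumes the whole prompt and returns logits plus the KV cache, then a
# KV-cache graph decodes one token at a time while reusing that cache.
def generate(prompt_tokens, max_new_tokens, prefill_graph, kv_graph):
    # Prefill phase: one batched pass over the prompt builds the KV cache.
    logits, kv_cache = prefill_graph(prompt_tokens)
    tokens = list(prompt_tokens) + [int(logits[-1].argmax())]

    # KV-cache phase: feed only the newest token plus the cached keys/values.
    for _ in range(max_new_tokens - 1):
        logits, kv_cache = kv_graph(tokens[-1], kv_cache)
        tokens.append(int(logits.argmax()))
    return tokens
```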
## Instructions
### Note
1. For hybrid mode, the export time is longer and can take 1-4 hours to complete, depending on the specific model being exported.
2. When exporting a hybrid mode model, memory consumption is higher. Taking LLAMA3.2 1B as an example, please ensure the device has at least 80 GB of memory and swap space.


### Step 1: Setup
1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build the Qualcomm AI Engine Direct Backend.

### Step 2: Prepare Model

#### LLAMA2
Download and prepare the stories110M model:

```bash
# tokenizer.model & stories110M.pt:
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"

# tokenizer.bin:
python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# params.json:
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
```

#### LLAMA3.2
Follow the [instructions](https://www.llama.com/) to download the models.
At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.

### Step 3: Run default examples using hybrid mode
#### LLAMA2
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "Once upon a time"
```

#### LLAMA3.2
Default example using hybrid mode:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1"
```

### Additional configs when running the script
If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --compile_only
```

On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
```

examples/qualcomm/oss_scripts/llama3_2/TARGETS renamed to examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,12 +8,12 @@ oncall("executorch")
 python_binary(
     name = "llama",
     srcs = ["llama.py"],
-    main_function = "executorch.examples.qualcomm.oss_scripts.llama3_2.llama.main",
+    main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
     preload_deps = [
         "//executorch/extension/llm/custom_ops:model_sharding_py",
     ],
     deps = [
-        "//executorch/examples/qualcomm/oss_scripts/llama2:static_llama",
+        "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
         "//caffe2:torch",
         "//executorch/extension/pybindings:aten_lib",
         "//executorch/backends/qualcomm/partition:partition",
```
