
Commit 65e3c00

Revert "Qualcomm AI Engine Direct - Unify Llama2&Llama3 and Small Accuracy Improvement." (#7889)
Revert "Qualcomm AI Engine Direct - Unify Llama2&Llama3 and Small Accuracy Im…" This reverts commit 66bfd75.
1 parent bebceb7 commit 65e3c00

24 files changed, +2022 −332 lines

backends/qualcomm/_passes/insert_requantize.py

Lines changed: 6 additions & 0 deletions
@@ -89,9 +89,15 @@ def _single_output_annotation(
         requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
         # {quant_attr: user_node_name_list}
         group_quant_attr_dict = self._invert_dict(requantize_dict)
+        # TODO: If users of the node contain output node,
+        # we replace the node with to_copy op. However, it would
+        # be problem when the node has multiple to_copy ops
+        add_output = len(group_quant_attr_dict) == 1

         for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
             user_nodes_copy = user_nodes.copy()
+            if add_output:
+                user_nodes_copy.append("output")
             self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

     def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
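The restored pass groups a node's requantize users by their quantization attributes and, when only one group exists, also appends the graph "output" so it is routed through the inserted to_copy. The standalone sketch below illustrates that grouping step; the dictionary contents and the `invert_requantize_dict` helper name are illustrative, not the actual ExecuTorch implementation.

```python
from collections import defaultdict


def invert_requantize_dict(requantize_dict):
    """Invert {user_node_name: quant_attr} into {quant_attr: [user_node_names]}.

    quant_attr must be hashable here, e.g. a (scale, zero_point, dtype) tuple.
    """
    grouped = defaultdict(list)
    for user_name, quant_attr in requantize_dict.items():
        grouped[quant_attr].append(user_name)
    return dict(grouped)


# Hypothetical requantize metadata for one node: two users share one encoding.
requantize = {
    "linear_1": (0.02, 128, "uint8"),
    "linear_2": (0.02, 128, "uint8"),
}

groups = invert_requantize_dict(requantize)
# Only one distinct quant encoding -> the graph output can reuse the same
# inserted to_copy; this is the condition the real pass calls `add_output`.
add_output = len(groups) == 1

for quant_attr, user_nodes in groups.items():
    user_nodes = user_nodes.copy()
    if add_output:
        user_nodes.append("output")
    print(quant_attr, "->", user_nodes)
```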

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 6 additions & 68 deletions
@@ -14,80 +14,17 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
-    QuantizationSpec,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


-def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
-    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
-        input_qspec_map = {}
-        input_act = node.args[0]
-        input_spec = quantization_config.input_activation
-        input_qspec_map[input_act] = input_spec
-
-        weight = node.args[1]
-        input_qspec_map[weight] = quantization_config.weight
-
-        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=quantization_config.output_activation,
-            _annotated=True,
-        )
-
-    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
-        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
-    )
-    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
-            if "nn_module_stack" in node.meta:
-                module_values_list = list(node.meta["nn_module_stack"].values())
-                full_qualified_name = module_values_list[-1][0]
-                if full_qualified_name == "output.conv":
-                    annotate_conv2d(
-                        node, quantization_config=quantization_config_16a8w_per_channel
-                    )
-
-
-def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            for index, prefill_output in enumerate(node.args[0]):
-                kv_quant_attr = kv_quant_attrs[index]
-                fixed_observer = FixedQParamsObserver.with_args(
-                    scale=kv_quant_attr[0],
-                    zero_point=kv_quant_attr[1],
-                    quant_min=kv_quant_attr[2],
-                    quant_max=kv_quant_attr[3],
-                    dtype=kv_quant_attr[4],
-                    qscheme=torch.torch.per_tensor_affine,
-                )
-
-                fixed_output_spec = QuantizationSpec(
-                    quant_min=kv_quant_attr[2],
-                    quant_max=kv_quant_attr[3],
-                    dtype=kv_quant_attr[4],
-                    ch_axis=0,
-                    observer_or_fake_quant_ctr=fixed_observer,
-                )
-
-                input_qspec_map = {}
-                for input in prefill_output.args:
-                    if isinstance(input, Node):
-                        input_qspec_map[input] = fixed_output_spec
-
-                prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=fixed_output_spec,
-                    _annotated=True,
-                )
-
-
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+def annotate_matmul_16a8w(  # noqa: C901
+    gm: torch.fx.GraphModule, traverse_input1=True
+) -> None:
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and

@@ -205,7 +142,8 @@ def annotate_matmul_input1(node: Node):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            annotate_matmul_input1(node.args[1])
+            if traverse_input1:
+                annotate_matmul_input1(node.args[1])


 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
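The surviving behavioral change in this file is the `traverse_input1` flag on `annotate_matmul_16a8w`: when it is true, the pass also walks the second matmul input (the K/V path) and annotates its producers. The sketch below illustrates that traversal pattern on a plain torch.fx graph; it uses `torch.fx.symbolic_trace`, so the matmul target is the Python `torch.matmul` rather than the `torch.ops.aten.matmul.default` seen by the real quantizer, and it only prints what it would annotate.

```python
import torch
import torch.fx


class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-2, -1))
        probs = torch.softmax(scores, dim=-1)
        return torch.matmul(probs, v)


def annotate_matmuls(gm: torch.fx.GraphModule, traverse_input1: bool = True):
    """Sketch of the traversal in annotate_matmul_16a8w (prints instead of tagging)."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.matmul:
            print(f"annotate 16a8w matmul: {node.name}")
            if traverse_input1:
                # The real pass recursively follows node.args[1] (e.g. the cat /
                # transpose chain feeding the K/V cache) and annotates it as well.
                print(f"  would traverse input1 chain starting at: {node.args[1].name}")


gm = torch.fx.symbolic_trace(TinyAttention())
annotate_matmuls(gm, traverse_input1=True)
annotate_matmuls(gm, traverse_input1=False)
```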

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 1 addition & 3 deletions
@@ -3529,7 +3529,7 @@ def test_stories_single_llama(self):

         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
             "--artifact",
             self.artifact_dir,
             "--build_folder",

@@ -3556,8 +3556,6 @@ def test_stories_single_llama(self):
             "16a4w",
             "--temperature",
             "0",
-            "--llama_model",
-            "stories110m",
         ]
         if self.host:
             cmds.extend(["--host", self.host])
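For context, the reverted test drives the llama2 example end to end through the script path shown above. The rough sketch below shows how such a command list could be assembled; the paths and the flag subset are illustrative and do not reproduce the full test.

```python
import subprocess

# Illustrative values; the real test fills these from its fixture attributes.
executorch_root = "/path/to/executorch"
artifact_dir = "./stories110m_artifacts"

cmds = [
    "python",
    f"{executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
    "--artifact", artifact_dir,
    "--ptq", "16a4w",
    "--temperature", "0",
]

host = None  # e.g. an adb-over-network address when targeting a remote device
if host:
    cmds.extend(["--host", host])

# Print the command; the real test launches it, roughly subprocess.run(cmds, check=True).
print(" ".join(cmds))
```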

examples/qualcomm/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
@@ -84,8 +84,11 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
 # build qnn_executor_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)

-# build qnn_llama_runner for llama
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)
+# build qnn_llama_runner for llama2
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)
+
+# build qnn_llama_runner for llama3.2
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)

 # build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)

examples/qualcomm/README.md

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@ This directory contains examples for some AI models.

 We have seperated the example scripts into the following subfolders, please refer to [README.md](../../backends/qualcomm/README.md) for the example scripts' directory structure:

-1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama](./oss_scripts/llama/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
+1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.

 2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner.
-For example, [llama](./oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
+For example, [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.

 3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Different from oss_scripts & scripts, which the example scripts are converting a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionaly, users can find customized example runners specific to the QAIHub models for execution. For example [qaihub_llama2_7b](./qaihub_scripts/llama2/qaihub_llama2_7b.py) is a script converting context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](./qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner to execute llama2 .pte files. Please be aware that context-binaries downloaded from QAIHub are tied to a specific QNN SDK version.
 Before executing the scripts and runner, please ensure that you are using the QNN SDK version that is matching the context binary. Tutorial below will also cover how to check the QNN Version for a context binary.

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 0 additions & 70 deletions
This file was deleted.
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set(_qnn_llama_runner__srcs ${_llama_runner__srcs})
+
+# preprocess qnn llama runner src files
+list(TRANSFORM _qnn_llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
+list(FILTER _qnn_llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
+list(
+  PREPEND
+  _qnn_llama_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
+)
+
+# build qnn llama runner
+add_executable(qnn_llama_runner ${_qnn_llama_runner__srcs})
+target_include_directories(
+  qnn_llama_runner PUBLIC ${_common_include_directories}
+)
+target_link_libraries(
+  qnn_llama_runner
+  qnn_executorch_backend
+  full_portable_ops_lib
+  extension_data_loader
+  extension_module
+  extension_tensor
+  gflags
+  re2::re2
+)
+target_compile_options(qnn_llama_runner PUBLIC ${_common_compile_options})
+set_target_properties(
+  qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+)
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Summary
+
+## Overview
+This file provides instructions to run LLAMA2 with different parameters via the Qualcomm HTP backend. The following settings support Stories 110M.
+
+Please check the corresponding section for more information.
+
+## Stories 110M
+This example demonstrates how to run a smaller LLAMA2, stories110M, on mobile via the Qualcomm HTP backend. The model architecture is fine-tuned specifically for HTP to accelerate performance. Weights are quantized via PTQ to fit the model on a phone.
+
+### Instructions
+#### Step 1: Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
+2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build the Qualcomm AI Engine Direct Backend.
+
+#### Step 2: Prepare Model
+Download and prepare the stories110M model:
+
+```bash
+# tokenizer.model & stories110M.pt:
+wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
+wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
+
+# tokenizer.bin:
+python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+
+# params.json:
+echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
+```
+
+#### Step 3: Run default examples
+The default example generates a story based on the given prompt, "Once".
+```bash
+# 16a4w quant:
+python examples/qualcomm/oss_scripts/llama2/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --prompt "Once"
+```
+
+#### (Note) Customized PTQ data set
+User prompts are used as PTQ calibration data. In the example above, the word "Once" is the only calibration input. If you want the observers to see more data during calibration, add more prompts to the `--prompt` argument.
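The calibration note above is the main lever on PTQ accuracy for this example: every string passed to `--prompt` becomes calibration data, so richer prompts give the activation observers a more representative range. The toy sketch below illustrates that mechanism with a hypothetical tokenizer and a single embedding layer standing in for the prepared model; it is not the script's actual calibration path.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# Stand-ins: a toy "model" layer and a toy tokenizer. In the real flow the
# observer-instrumented export plays this role and the prompts come from --prompt.
embedding = torch.nn.Embedding(32000, 768)
observer = MinMaxObserver(dtype=torch.quint8)


def fake_tokenize(prompt: str) -> torch.Tensor:
    # Hypothetical tokenizer: hash words into a small id range for illustration.
    return torch.tensor([[abs(hash(w)) % 32000 for w in prompt.split()]])


calibration_prompts = ["Once", "Once upon a time", "Once there was a tiny llama"]

for prompt in calibration_prompts:
    activations = embedding(fake_tokenize(prompt))
    observer(activations)  # observers accumulate min/max over all prompts

scale, zero_point = observer.calculate_qparams()
print("calibrated scale/zero_point:", scale.item(), zero_point.item())
```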
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbsource//xplat/executorch/backends/qualcomm/qnn_version.bzl", "get_qnn_library_verision")
+load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+
+python_library(
+    name = "static_llama",
+    srcs = [
+        "model/static_llama.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+python_binary(
+    name = "llama",
+    srcs = ["llama.py"],
+    main_function = "executorch.examples.qualcomm.oss_scripts.llama2.llama.main",
+    deps = [
+        ":static_llama",
+        "//caffe2:torch",
+        "//executorch/extension/pybindings:aten_lib",
+        "//executorch/backends/qualcomm/partition:partition",
+        "//executorch/backends/qualcomm/quantizer:quantizer",
+        "//executorch/devtools:lib",
+        "//executorch/examples/models:models",
+        "//executorch/examples/qualcomm:utils",
+        "//executorch/extension/export_util:export_util",
+        "//executorch/extension/llm/export:export_lib",
+    ],
+)
+
+runtime.command_alias(
+    name = "llama_qnn",
+    env = {
+        "LD_LIBRARY_PATH": "$(location fbsource//third-party/qualcomm/qnn/qnn-{0}:qnn_offline_compile_libs)".format(get_qnn_library_verision()),
+    },
+    exe = ":llama",
+)
