Commit 31dbfc9

Qualcomm AI Engine Direct - support llama3.2 1B/3B with static llama in kv mode
Differential Revision: D65843945
Pull Request resolved: #6779
1 parent 473bc7b commit 31dbfc9

13 files changed: +1492 −85 lines changed


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 34 additions & 1 deletion
@@ -11,7 +11,10 @@
     get_default_8bit_qnn_ptq_config,
     QuantizationConfig,
 )
-from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from executorch.backends.qualcomm.quantizer.utils import (
+    get_ptq_per_channel_quant_config,
+    QUANT_ANNOTATION_KEY,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -121,6 +124,36 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
             annotate_matmul_input1(node.args[1], quantization_config_8a8w)


+def custom_annotate_llama_last_conv_16a8w(gm: torch.fx.GraphModule) -> None:
+    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        weight = node.args[1]
+        input_qspec_map[weight] = quantization_config.weight
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8
+    )
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[0][0]
+                if full_qualified_name == "L['self'].llama.output":
+                    annotate_conv2d(
+                        node, quantization_config=quantization_config_16a8w_per_channel
+                    )
+
+
 def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule):
     """
     Annotate matmul op with 16a8w quantization config
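
The new custom_annotate_llama_last_conv_16a8w hook pins the lm-head conv (lowered from the final linear) to 16-bit activations with 8-bit per-channel weights. A minimal sketch of registering it next to the existing matmul annotation, assuming the QnnQuantizer.add_custom_quant_annotations helper used elsewhere in the Qualcomm examples:

    from executorch.backends.qualcomm.quantizer.custom_annotation import (
        custom_annotate_llama_last_conv_16a8w,
        custom_annotate_matmul_16a8w,
    )
    from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

    quantizer = QnnQuantizer()
    # The custom hooks add 16a8w annotations for matmuls and the final output
    # conv on top of the quantizer's default per-op config.
    quantizer.add_custom_quant_annotations(
        (custom_annotate_matmul_16a8w, custom_annotate_llama_last_conv_16a8w)
    )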

backends/qualcomm/utils/utils.py

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ def _transform(
 def capture_program(
     module: torch.nn.Module,
     inputs: Tuple[torch.Tensor],
-    custom_pass_config: Set[str] = None,
+    custom_pass_config: Set[str] = frozenset(),
 ) -> exir.ExirExportedProgram:
     ep = torch.export.export(module, inputs)
     decomposed_ep = ep.run_decompositions(get_decomp_table())
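
The default changes from None to frozenset() so callers get a real, immutable empty default instead of a sentinel that must be checked. A small stand-alone illustration of the pitfall this sidesteps (function and pass names here are hypothetical):

    def capture_bad(custom_pass_config=set()):
        # A mutable default is shared across calls, so state leaks between them.
        custom_pass_config.add("ExamplePass")
        return custom_pass_config

    def capture_good(custom_pass_config=frozenset()):
        # An immutable default cannot be mutated in place; callers pass their own set.
        return set(custom_pass_config) | {"ExamplePass"}

    assert capture_bad() is capture_bad()        # both calls return the same shared set
    assert capture_good() == {"ExamplePass"}     # fresh result on every call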

examples/qualcomm/CMakeLists.txt

Lines changed: 19 additions & 1 deletion
@@ -66,12 +66,30 @@ target_include_directories(
   full_portable_ops_lib PUBLIC ${_common_include_directories}
 )

+# find RE2 for tokenizer
+set(ABSL_ENABLE_INSTALL ON)
+set(ABSL_PROPAGATE_CXX_STD ON)
+set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE})
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+add_subdirectory(
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/abseil-cpp
+  ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp
+)
+add_subdirectory(
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/re2
+  ${CMAKE_CURRENT_BINARY_DIR}/re2
+)
+set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
+
 # build qnn_executor_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)

-# build qnn_llama_runner
+# build qnn_llama_runner for llama2
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)

+# build qnn_llama_runner for llama3.2
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)
+
 # build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
examples/qualcomm/oss_scripts/llama2/llama.py

File mode changed from 100644 to 100755
Lines changed: 0 additions & 3 deletions
@@ -333,9 +333,6 @@ def lowering_modules(
     def get_example_inputs(self):
         return self.llama_model.get_example_inputs()

-    def get_export_inputs(self):
-        return self.llama_model.get_export_inputs()
-

 def compile(args):
     os.makedirs(args.artifact, exist_ok=True)

examples/qualcomm/oss_scripts/llama2/model/static_llama.py

File mode changed from 100644 to 100755
Lines changed: 46 additions & 64 deletions
@@ -17,6 +17,20 @@
 )


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def apply_rotary_emb_single(
     x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ) -> torch.Tensor:
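
A quick way to sanity-check the new helper against the behaviour its docstring claims (illustrative only, with repeat_kv defined as above):

    import torch

    x = torch.randn(1, 8, 5, 64)  # (batch, num_kv_heads, seqlen, head_dim)
    assert torch.equal(
        repeat_kv(x, 4), torch.repeat_interleave(x, repeats=4, dim=1)
    )  # both yield shape (1, 32, 5, 64)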
@@ -59,13 +73,13 @@ def prepare_sha(self):
         self.wk_sha = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=False)
-                for _ in range(self.n_heads)
+                for _ in range(self.n_kv_heads)
             ]
         )
         self.wv_sha = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=False)
-                for _ in range(self.n_heads)
+                for _ in range(self.n_kv_heads)
             ]
         )


@@ -76,6 +90,7 @@ def prepare_sha(self):
             self.wq_sha[i].weight.data.copy_(
                 self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
             )
+        for i in range(self.n_kv_heads):
             self.wk_sha[i].weight.data.copy_(
                 self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
             )
@@ -97,30 +112,27 @@ def forward_sha(
         v = [wv_sha(hidden_states) for wv_sha in self.wv_sha]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+        for i in range(len(k)):
             k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)

-        output_kh, output_vh, output_y = [], [], []
+        output_y = []
+        kh, vh = [], []
         for i, _ in enumerate(k_caches):
-            # cat at the seq dim
-            kh = torch.cat([k_caches[i], k[i]], dim=-1)
-            vh = torch.cat([v_caches[i], v[i]], dim=1)
+            kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
+            vh.append(torch.cat([v_caches[i], v[i]], dim=1))

-            attn = q[i] @ kh
+        for i, _ in enumerate(q):
+            cache_idx = i // self.num_key_value_groups
+            attn = q[i] @ kh[cache_idx]
             attn = attn / self.scale + atten_mask
             attn = self.attn_softmax(attn)
-            y = attn @ vh
+            y = attn @ vh[cache_idx]

-            if self.output_new_cache_only:
-                output_kh.append(k[i])
-                output_vh.append(v[i])
-            else:
-                output_kh.append(kh)
-                output_vh.append(vh)
             output_y.append(y)

         y = torch.concat(output_y, dim=-1)
         y = self.wo(y)
-        return y, output_kh, output_vh
+        return y, k, v

     def forward(
         self,
self,
@@ -142,24 +154,28 @@ def forward(
         k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)

         output_kh, output_vh, output_y = [], [], []
-
+        kh, vh = [], []
         for i, _ in enumerate(k_caches):
-            # cat at the seq dim
-            kh = torch.cat([k_caches[i], k[:, i, :, :]], dim=-1)
-            vh = torch.cat([v_caches[i], v[:, :, i, :]], dim=1)
+            kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1))
+            vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1))
+
+        for i in range(self.n_heads):
+            cache_idx = i // self.num_key_value_groups

-            attn = q[:, :, i, :] @ kh
+            attn = q[:, :, i, :] @ kh[cache_idx]
             attn = attn / self.scale + atten_mask
             attn = self.attn_softmax(attn)
-            y = attn @ vh
+            y = attn @ vh[cache_idx]

+            output_y.append(y)
+
+        for i in range(len(k_caches)):
             if self.output_new_cache_only:
                 output_kh.append(k[:, i, :, :])
                 output_vh.append(v[:, :, i, :])
             else:
-                output_kh.append(kh)
-                output_vh.append(vh)
-            output_y.append(y)
+                output_kh.append(kh[i])
+                output_vh.append(vh[i])

         y = torch.concat(output_y, dim=-1)
         y = self.wo(y)
@@ -246,10 +262,10 @@ def forward(

         hidden_states = self.tok_embeddings(tokens)
         for ind, decoder_layer in enumerate(self.layers):
-            offset_k = ind * self.n_heads
-            offset_v = self.n_layers * self.n_heads + offset_k
-            k_caches = args[offset_k : offset_k + self.n_heads]
-            v_caches = args[offset_v : offset_v + self.n_heads]
+            offset_k = ind * self.n_kv_heads
+            offset_v = self.n_layers * self.n_kv_heads + offset_k
+            k_caches = args[offset_k : offset_k + self.n_kv_heads]
+            v_caches = args[offset_v : offset_v + self.n_kv_heads]
             hidden_states, k, v = decoder_layer(
                 hidden_states,
                 freqs_cos=freqs_cos,
@@ -275,7 +291,7 @@ def get_example_inputs(self):
         atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
         atten_mask[:, -1] = 0
         for _ in range(self.n_layers):
-            for _ in range(self.n_heads):
+            for _ in range(self.n_kv_heads):
                 # transpose first to decrease the runtime efforts
                 k_cache.append(
                     torch.zeros(
@@ -299,40 +315,6 @@ def get_example_inputs(self):
             v_cache,
         )

-    def get_export_inputs(self):
-        tokens = torch.randint(
-            self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
-        )
-        pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
-        # this is important for torch.export not to take it as dummy input
-        k_cache, v_cache = [], []
-        atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
-        atten_mask[:, -1] = 0
-        for _ in range(self.n_layers):
-            for _ in range(self.n_heads):
-                # transpose first to decrease the runtime efforts
-                k_cache.append(
-                    torch.randn(
-                        self.max_batch_size,
-                        self.head_dim,
-                        self.max_seq_len - 1,
-                    )
-                )
-                v_cache.append(
-                    torch.randn(
-                        self.max_batch_size,
-                        self.max_seq_len - 1,
-                        self.head_dim,
-                    )
-                )
-        return (
-            tokens,
-            pos_ids,
-            atten_mask,
-            k_cache,
-            v_cache,
-        )
-
     def get_metadata(self):
         # TODO: modify this when enabling LLAMA 7B
         return {
@@ -344,7 +326,7 @@ def get_metadata(self):
             "get_max_seq_len": self.max_seq_len,
             "get_n_bos": 1,
             "get_n_eos": 1,
-            "get_n_kv_heads": self.n_heads,
+            "get_n_kv_heads": self.n_kv_heads,
             "get_n_layers": self.n_layers,
             "get_vocab_size": self.vocab_size,
         }
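
The n_kv_heads / num_key_value_groups changes implement grouped-query attention over the split-head caches: several query heads now share one K/V cache slot instead of each head owning its own. A short sketch of the head-to-cache mapping, using llama3.2-1B-like head counts (32 query heads, 8 KV heads) as an assumed example:

    n_heads, n_kv_heads = 32, 8                    # assumed llama3.2-1B-style config
    num_key_value_groups = n_heads // n_kv_heads   # 4 query heads per KV head

    # Query head i reads k_caches[cache_idx] / v_caches[cache_idx], exactly as in
    # the forward loops above.
    mapping = [i // num_key_value_groups for i in range(n_heads)]
    assert mapping[:8] == [0, 0, 0, 0, 1, 1, 1, 1]
    assert len(set(mapping)) == n_kv_heads         # only n_kv_heads caches are kept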
examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# model sharding with custom op
+set(CUSTOM_OP_SRCS_FILE
+    "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp"
+)
+add_library(custom_ops ${CUSTOM_OP_SRCS_FILE})
+target_include_directories(custom_ops PUBLIC "${_common_include_directories}")
+target_include_directories(
+  custom_ops PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../include"
+)
+target_link_libraries(
+  custom_ops PUBLIC full_portable_ops_lib
+)
+target_link_options_shared_lib(custom_ops)
+
+# preprocess qnn runner src files for llama3.2
+set(_llama3_2_runner__srcs ${_llama_runner__srcs})
+list(TRANSFORM _llama3_2_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
+list(FILTER _llama3_2_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
+list(
+  PREPEND
+  _llama3_2_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/qnn_llama3_2_runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
+  ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
+)
+
+list(
+  APPEND _llama3_2_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
+)
+list(
+  APPEND
+  _llama3_2_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp
+)
+
+# build qnn llama3.2 1b runner
+add_executable(qnn_llama3_2_1b_runner ${_llama3_2_runner__srcs})
+target_include_directories(
+  qnn_llama3_2_1b_runner PUBLIC ${_common_include_directories}
+)
+
+target_link_libraries(
+  qnn_llama3_2_1b_runner
+  qnn_executorch_backend
+  executorch_core
+  extension_data_loader
+  extension_module
+  extension_tensor
+  gflags
+  re2::re2
+  custom_ops
+)
+target_compile_options(
+  qnn_llama3_2_1b_runner PUBLIC ${_common_compile_options}
+)
+set_target_properties(
+  qnn_llama3_2_1b_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+)
+
+
+# build qnn llama3.2 3b runner
+add_executable(qnn_llama3_2_3b_runner ${_llama3_2_runner__srcs})
+target_include_directories(
+  qnn_llama3_2_3b_runner PUBLIC ${_common_include_directories}
+)
+# Adding compile option to differentiate llama3.2 1b with 3b
+target_compile_options(qnn_llama3_2_3b_runner PRIVATE -DLLAMA3_2_3B_RUNNER)
+
+target_link_libraries(
+  qnn_llama3_2_3b_runner
+  qnn_executorch_backend
+  executorch_core
+  extension_data_loader
+  extension_module
+  extension_tensor
+  gflags
+  re2::re2
+  custom_ops
+)
+target_compile_options(
+  qnn_llama3_2_3b_runner PUBLIC ${_common_compile_options}
+)
+set_target_properties(
+  qnn_llama3_2_3b_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+)
