
Commit 576e96c

Qualcomm AI Engine Direct - Add llama sha transforming pass
Differential Revision: D64435128 Pull Request resolved: #6211
1 parent 623a9a6 commit 576e96c

File tree

examples/models/llama/TARGETS
examples/models/llama/export_llama.py
examples/models/llama/export_llama_lib.py
examples/models/llama/llama_transformer.py
examples/models/llama/source_transformation/attention.py

5 files changed: +267 -16 lines changed

examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/attention.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
         "source_transformation/prune_vocab.py",

examples/models/llama/export_llama.py

Lines changed: 3 additions & 0 deletions
@@ -7,11 +7,14 @@
 # Example script for exporting Llama2 to flatbuffer

 import logging
+import sys

 import torch

 from .export_llama_lib import build_args_parser, export_llama

+sys.setrecursionlimit(4096)
+

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
examples/models/llama/export_llama_lib.py

Lines changed: 44 additions & 15 deletions
@@ -50,6 +50,8 @@
     fuse_layer_norms,
     get_model_with_r1_r2,
 )
+
+from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--use_qnn_sha",
+        action="store_true",
+        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
@@ -700,15 +708,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
             get_custom_quant_ios_dtype,
         )

+        atten = builder_exported_to_edge.model.layers[0].attention
+        if args.use_qnn_sha:
+            cache_shape = torch.Size(
+                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+            )
+        else:
+            cache_shape = torch.Size(
+                (
+                    atten.max_batch_size,
+                    atten.max_seq_len,
+                    atten.n_kv_heads,
+                    atten.head_dim,
+                )
+            )
         # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(
-                get_custom_quant_ios_dtype,  # pyre-ignore
-                builder_exported_to_edge.model.layers[
-                    0
-                ].attention.kv_cache.past_k_caches.shape,
-            ),
+            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
         )

         logging.info("Lowering model using following partitioner(s): ")
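
With --use_qnn_sha, every head owns its own KV cache buffer, so the shape handed to get_custom_quant_ios_dtype drops the n_kv_heads dimension; without the flag, the single shared cache keeps it. A small sketch of the two shapes, using hypothetical Llama-style sizes (batch 1, sequence length 512, 32 KV heads, head_dim 128) purely for illustration:

import torch

max_batch_size, max_seq_len, n_kv_heads, head_dim = 1, 512, 32, 128

# SHA path: one cache buffer per head, so no head dimension in the shape.
sha_cache_shape = torch.Size((max_batch_size, max_seq_len, head_dim))

# Default MHA path: a single buffer holds every KV head.
mha_cache_shape = torch.Size((max_batch_size, max_seq_len, n_kv_heads, head_dim))

print(sha_cache_shape)  # torch.Size([1, 512, 128])
print(mha_cache_shape)  # torch.Size([1, 512, 32, 128])
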
@@ -977,15 +994,27 @@ def _get_source_transforms( # noqa
             convert_linear_to_conv2d,
         )

-        transforms.append(replace_kv_cache_with_simple_kv_cache)
-        transforms.append(replace_sdpa_with_flex_sdpa)
-        transforms.append(replace_causal_mask)
-        transforms.append(replace_rms_norm_with_native_rms_norm)
-        if args.optimized_rotation_path:
-            transforms.append(fuse_layer_norms)
-            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-        # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
-        transforms.append(convert_linear_to_conv2d)
+        if args.use_qnn_sha:
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(replace_attention_to_attention_sha)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(convert_linear_to_conv2d)
+        else:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(convert_linear_to_conv2d)

     elif args.mps:
         # Currently mps doesn't support sdpa op, use the simpler decomposition
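
The transforms collected here are applied in order to the eager model before lowering; in the new branch, replace_attention_to_attention_sha is appended before convert_linear_to_conv2d, presumably so the per-head nn.Linear layers it creates are also converted. A minimal sketch of how an ordered transform list like this is typically applied (the helper name is illustrative, not part of the commit):

from typing import Callable, List

from torch import nn


def apply_source_transforms(
    model: nn.Module,
    transforms: List[Callable[[nn.Module], nn.Module]],
) -> nn.Module:
    # Each transform takes a module and returns the rewritten module,
    # so later transforms see the output of earlier ones.
    for transform in transforms:
        model = transform(model)
    return model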

examples/models/llama/llama_transformer.py

Lines changed: 0 additions & 1 deletion
@@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
         self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
examples/models/llama/source_transformation/attention.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Example script for exporting Llama2 to flatbuffer

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama.llama_transformer import Attention
from torch import nn


def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


class KVCacheSHA(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()

        # a buffer per head
        cache_shape = (max_batch_size, max_seq_length, head_dim)
        for i in range(n_heads):
            self.register_buffer(
                f"past_k_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )
            self.register_buffer(
                f"past_v_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        cache_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_k = torch.ops.aten.index_put_(
            getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
        )
        new_v = torch.ops.aten.index_put_(
            getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
        )
        return new_k, new_v

    def get_cache(self, head_idx):
        return getattr(self, f"past_k_caches_{head_idx}"), getattr(
            self, f"past_v_caches_{head_idx}"
        )


class SDPASHA(torch.nn.Module):

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        n_rep: int,
        head_dim: int,
        dim: int,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.dim = dim
        self.kv_cache = KVCacheSHA(
            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
        )
        self.scale_factor = math.sqrt(head_dim)

    def forward(
        self,
        input_pos: torch.Tensor,
        qs: List[torch.Tensor],
        ks: List[torch.Tensor],
        vs: List[torch.Tensor],
        mask,
    ):

        transpose_ks = []
        for i in range(len(ks)):
            new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
            transpose_ks.append(new_k.transpose(-2, -1).contiguous())

        output = []
        for i, q in enumerate(qs):
            cache_idx = i // self.n_rep
            _, v = self.kv_cache.get_cache(cache_idx)

            attn_mask = mask[input_pos]

            attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
            attn_weight += attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            output.append(attn_weight @ v.contiguous())

        return torch.cat(output, dim=-1)


class AttentionSHA(nn.Module):
    def __init__(self, attention_mha: nn.Module):
        super().__init__()
        if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not support")

        self.n_heads = attention_mha.n_heads
        self.n_kv_heads = attention_mha.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.dim = attention_mha.dim
        self.max_batch_size = attention_mha.max_batch_size
        self.max_seq_len = attention_mha.max_seq_len
        self.head_dim = attention_mha.dim // self.n_heads
        self.SDPA = SDPASHA(
            self.max_batch_size,
            self.max_seq_len,
            self.n_heads,
            self.n_rep,
            self.head_dim,
            self.dim,
        )
        self.wq = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        for i in range(self.n_heads):
            self.wq[i].weight.data.copy_(
                attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk[i].weight.data.copy_(
                attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv[i].weight.data.copy_(
                attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        self.wo = attention_mha.wo

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        # QKV
        q = [wq(x) for wq in self.wq]
        k = [wk(x) for wk in self.wk]
        v = [wv(x) for wv in self.wv]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

        output = self.SDPA(input_pos, q, k, v, self.mask)
        return self.wo(output)


def replace_attention_to_attention_sha(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, Attention):
            setattr(
                module,
                name,
                AttentionSHA(child),
            )
        else:
            replace_attention_to_attention_sha(child)
    return module
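
AttentionSHA splits each fused projection into per-head nn.Linear modules by copying row slices of the original weight, and KVCacheSHA keeps one cache buffer per KV head. A small sanity sketch (illustrative only, not part of the commit) showing that the row slicing used in AttentionSHA.__init__ reproduces the fused wq output head by head:

import torch
from torch import nn

dim, n_heads = 64, 4          # toy sizes for illustration
head_dim = dim // n_heads

fused_wq = nn.Linear(dim, n_heads * head_dim, bias=False)
per_head_wq = nn.ModuleList(
    [nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads)]
)
for i in range(n_heads):
    # Same row slicing as AttentionSHA.__init__ above.
    per_head_wq[i].weight.data.copy_(
        fused_wq.weight[i * head_dim : (i + 1) * head_dim]
    )

x = torch.randn(1, 8, dim)    # (batch, seq, dim)
fused_out = fused_wq(x)
sha_out = torch.cat([wq(x) for wq in per_head_wq], dim=-1)
assert torch.allclose(fused_out, sha_out, atol=1e-6)

In the export flow, this per-module swap is performed by the replace_attention_to_attention_sha transform appended in export_llama_lib.py above.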
