
Commit 4cd737a

Author: Chun-I Tsai (committed)

Qualcomm AI Engine Direct - Add llama sha transforming pass
- Add SHA pass

1 parent 1f2b9aa
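
For context: the "SHA" pass rewrites multi-head attention into multiple single-head attentions by slicing the fused projection weights row-wise into one nn.Linear per head. A minimal standalone sketch of that equivalence (hypothetical sizes, not code from this commit):

import torch
from torch import nn

dim, n_heads = 64, 4                    # hypothetical sizes
head_dim = dim // n_heads

wq = nn.Linear(dim, n_heads * head_dim, bias=False)  # fused multi-head projection

# one Linear per head, its weight copied from the matching row block of wq
per_head = nn.ModuleList(
    nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads)
)
for i, lin in enumerate(per_head):
    lin.weight.data.copy_(wq.weight[i * head_dim : (i + 1) * head_dim])

x = torch.randn(2, 8, dim)  # (batch, seq, dim)
fused = wq(x)
split = torch.cat([lin(x) for lin in per_head], dim=-1)
assert torch.allclose(fused, split, atol=1e-6)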

File tree: 4 files changed, +261 −17 lines

examples/models/llama2/export_llama.py
Lines changed: 3 additions & 0 deletions

@@ -7,11 +7,14 @@
 # Example script for exporting Llama2 to flatbuffer

 import logging
+import sys

 import torch

 from .export_llama_lib import build_args_parser, export_llama

+sys.setrecursionlimit(4096)
+

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

examples/models/llama2/export_llama_lib.py
Lines changed: 29 additions & 8 deletions

@@ -50,6 +50,8 @@
     fuse_layer_norms,
     get_model_with_r1_r2,
 )
+
+from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,

@@ -174,6 +176,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--use_qnn_sha",
+        action="store_true",
+        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",

@@ -917,14 +925,27 @@ def _get_source_transforms( # noqa
             convert_linear_to_conv2d,
         )

-        transforms.append(replace_kv_cache_with_simple_kv_cache)
-        transforms.append(replace_sdpa_with_flex_sdpa)
-        transforms.append(replace_causal_mask)
-        transforms.append(replace_rms_norm_with_native_rms_norm)
-        if args.optimized_rotation_path:
-            transforms.append(fuse_layer_norms)
-            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-        transforms.append(convert_linear_to_conv2d)
+        if args.use_qnn_sha:
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(replace_attention_to_attention_sha)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(convert_linear_to_conv2d)
+        else:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(convert_linear_to_conv2d)

     elif args.mps:
         # Currently mps doesn't support sdpa op, use the simpler decomposition
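
For orientation, _get_source_transforms returns a list of callables, each taking an nn.Module and returning the (possibly rewritten) module; the export flow then applies them in order. A rough sketch of that convention (apply_source_transforms is a hypothetical helper, not a function from this file, and it assumes every transform follows the module-in/module-out signature used above):

from typing import Callable, List

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # apply each rewrite pass in sequence, threading the module through
    for transform in transforms:
        model = transform(model)
    return model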

examples/models/llama2/llama_transformer.py
Lines changed: 10 additions & 9 deletions

@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # args.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
examples/models/llama2/source_transformation/attention.py (new file)
Lines changed: 219 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Source transformation: replace multi-head attention with per-head
# (single-head) attention for the Qualcomm (QNN) backend.

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama2.llama_transformer import Attention
from torch import nn


def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


class KVCacheSha(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()

        # a buffer per head
        cache_shape = (max_batch_size, max_seq_length, head_dim)
        for i in range(n_heads):
            self.register_buffer(
                f"past_k_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )
            self.register_buffer(
                f"past_v_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        cache_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_k = torch.ops.aten.index_put_(
            getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
        )
        new_v = torch.ops.aten.index_put_(
            getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
        )
        return new_k, new_v

    def get_cache(self, head_idx):
        return getattr(self, f"past_k_caches_{head_idx}"), getattr(
            self, f"past_v_caches_{head_idx}"
        )


class SDPASha(torch.nn.Module):

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        n_rep: int,
        head_dim: int,
        dim: int,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.dim = dim
        self.kv_cache = KVCacheSha(
            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
        )
        self.scale_factor = math.sqrt(head_dim)

    def forward(
        self,
        input_pos: torch.Tensor,
        qs: List[torch.Tensor],
        ks: List[torch.Tensor],
        vs: List[torch.Tensor],
        mask,
    ):

        transpose_ks = []
        for i in range(len(ks)):
            new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
            transpose_ks.append(new_k.transpose(-2, -1).contiguous())

        output = []
        for i, q in enumerate(qs):
            cache_idx = i // self.n_rep
            _, v = self.kv_cache.get_cache(cache_idx)

            attn_mask = mask[input_pos]

            attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
            attn_weight += attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            output.append(attn_weight @ v.contiguous())

        return torch.cat(output, dim=-1)


class AttentionSha(nn.Module):
    def __init__(self, attention_mha: nn.Module):
        super().__init__()
        if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not supported")

        self.n_heads = attention_mha.n_heads
        self.n_kv_heads = attention_mha.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.dim = attention_mha.dim
        self.max_batch_size = attention_mha.max_batch_size
        self.max_seq_len = attention_mha.max_seq_len
        self.head_dim = attention_mha.dim // self.n_heads
        self.SDPA = SDPASha(
            self.max_batch_size,
            self.max_seq_len,
            self.n_heads,
            self.n_rep,
            self.head_dim,
            self.dim,
        )
        self.wq = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        for i in range(self.n_heads):
            self.wq[i].weight.data.copy_(
                attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk[i].weight.data.copy_(
                attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv[i].weight.data.copy_(
                attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        self.wo = attention_mha.wo

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        # QKV
        q = [wq(x) for wq in self.wq]
        k = [wk(x) for wk in self.wk]
        v = [wv(x) for wv in self.wv]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

        output = self.SDPA(input_pos, q, k, v, self.mask)
        return self.wo(output)


def replace_attention_to_attention_sha(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, Attention):
            setattr(
                module,
                name,
                AttentionSha(child),
            )
        else:
            replace_attention_to_attention_sha(child)
    return module
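
A hedged usage sketch of the new pass: replace_attention_to_attention_sha walks the module tree and swaps every Attention instance for an AttentionSha wrapper, so it can be applied directly to a KV-cache-enabled Transformer. The ModelArgs values below are illustrative and assume the remaining fields take their defaults:

from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama2.source_transformation.attention import (
    replace_attention_to_attention_sha,
)

# AttentionSha requires use_kv_cache=True (it raises NotImplementedError otherwise)
args = ModelArgs(dim=64, n_layers=2, n_heads=4, use_kv_cache=True)
model = Transformer(args)
model = replace_attention_to_attention_sha(model)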
