Commit 2f85f9e

Chun-I Tsai authored and Joey Tsai committed
Qualcomm AI Engine Direct - Add llama sha transforming pass
- Add SHA pass
1 parent 5b51bb8 commit 2f85f9e

File tree

4 files changed, +251 −10 lines


examples/models/llama/export_llama.py

Lines changed: 3 additions & 0 deletions
@@ -7,11 +7,14 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import logging
+import sys
 
 import torch
 
 from .export_llama_lib import build_args_parser, export_llama
 
+sys.setrecursionlimit(4096)
+
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

examples/models/llama/export_llama_lib.py

Lines changed: 29 additions & 9 deletions
@@ -50,6 +50,8 @@
     fuse_layer_norms,
     get_model_with_r1_r2,
 )
+
+from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--use_qnn_sha",
+        action="store_true",
+        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
@@ -947,15 +955,27 @@ def _get_source_transforms(  # noqa
             convert_linear_to_conv2d,
         )
 
-        transforms.append(replace_kv_cache_with_simple_kv_cache)
-        transforms.append(replace_sdpa_with_flex_sdpa)
-        transforms.append(replace_causal_mask)
-        transforms.append(replace_rms_norm_with_native_rms_norm)
-        if args.optimized_rotation_path:
-            transforms.append(fuse_layer_norms)
-            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-        # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
-        transforms.append(convert_linear_to_conv2d)
+        if args.use_qnn_sha:
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(replace_attention_to_attention_sha)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(convert_linear_to_conv2d)
+        else:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(convert_linear_to_conv2d)
 
     elif args.mps:
         # Currently mps doesn't support sdpa op, use the simpler decomposition
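
For context, each entry appended above is a source transform: a callable that takes the model's torch.nn.Module and returns the rewritten module (replace_attention_to_attention_sha in the new file has exactly this shape). A minimal sketch of how such a list is applied in order, assuming that convention; the helper name here is hypothetical and is not the function used in export_llama_lib.py:

from typing import Callable, List

import torch.nn as nn


def apply_source_transforms(
    model: nn.Module, transforms: List[Callable[[nn.Module], nn.Module]]
) -> nn.Module:
    # Each transform rewrites submodules (often in place) and returns the
    # module, so the next transform sees the previous result.
    for transform in transforms:
        model = transform(model)
    return model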

examples/models/llama/llama_transformer.py

Lines changed: 0 additions & 1 deletion
@@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
         self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
examples/models/llama/source_transformation/attention.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Example script for exporting Llama2 to flatbuffer

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama2.llama_transformer import Attention
from torch import nn


def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    x_r, x_i = x[..., ::2], x[..., 1::2]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
    return x_out


class KVCacheSha(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        super().__init__()

        # a buffer per head
        cache_shape = (max_batch_size, max_seq_length, head_dim)
        for i in range(n_heads):
            self.register_buffer(
                f"past_k_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )
            self.register_buffer(
                f"past_v_caches_{i}",
                torch.zeros(cache_shape, dtype=dtype, device="cpu"),
                persistent=False,
            )

    def update(
        self,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        cache_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        new_k = torch.ops.aten.index_put_(
            getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
        )
        new_v = torch.ops.aten.index_put_(
            getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
        )
        return new_k, new_v

    def get_cache(self, head_idx):
        return getattr(self, f"past_k_caches_{head_idx}"), getattr(
            self, f"past_v_caches_{head_idx}"
        )


class SDPASha(torch.nn.Module):

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        n_rep: int,
        head_dim: int,
        dim: int,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.dim = dim
        self.kv_cache = KVCacheSha(
            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
        )
        self.scale_factor = math.sqrt(head_dim)

    def forward(
        self,
        input_pos: torch.Tensor,
        qs: List[torch.Tensor],
        ks: List[torch.Tensor],
        vs: List[torch.Tensor],
        mask,
    ):

        transpose_ks = []
        for i in range(len(ks)):
            new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
            transpose_ks.append(new_k.transpose(-2, -1).contiguous())

        output = []
        for i, q in enumerate(qs):
            cache_idx = i // self.n_rep
            _, v = self.kv_cache.get_cache(cache_idx)

            attn_mask = mask[input_pos]

            attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
            attn_weight += attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            output.append(attn_weight @ v.contiguous())

        return torch.cat(output, dim=-1)


class AttentionSha(nn.Module):
    def __init__(self, attention_mha: nn.Module):
        super().__init__()
        if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not support")

        self.n_heads = attention_mha.n_heads
        self.n_kv_heads = attention_mha.n_kv_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.dim = attention_mha.dim
        self.max_batch_size = attention_mha.max_batch_size
        self.max_seq_len = attention_mha.max_seq_len
        self.head_dim = attention_mha.dim // self.n_heads
        self.SDPA = SDPASha(
            self.max_batch_size,
            self.max_seq_len,
            self.n_heads,
            self.n_rep,
            self.head_dim,
            self.dim,
        )
        self.wq = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv = nn.ModuleList(
            [
                nn.Linear(self.dim, self.head_dim, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )

        for i in range(self.n_heads):
            self.wq[i].weight.data.copy_(
                attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        for i in range(self.n_kv_heads):
            self.wk[i].weight.data.copy_(
                attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
            self.wv[i].weight.data.copy_(
                attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
            )
        self.wo = attention_mha.wo

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        # QKV
        q = [wq(x) for wq in self.wq]
        k = [wk(x) for wk in self.wk]
        v = [wv(x) for wv in self.wv]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

        output = self.SDPA(input_pos, q, k, v, self.mask)
        return self.wo(output)


def replace_attention_to_attention_sha(module: torch.nn.Module):
    for name, child in module.named_children():
        if isinstance(child, Attention):
            setattr(
                module,
                name,
                AttentionSha(child),
            )
        else:
            replace_attention_to_attention_sha(child)
    return module
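
The per-head wq/wk/wv linears above are populated by slicing rows of the fused MHA projection weights, so each single-head projection reproduces the corresponding slice of the original multi-head projection. A small standalone sketch of that equivalence, using toy sizes chosen here for illustration (not the real llama dimensions):

import torch
from torch import nn

dim, n_heads = 16, 4
head_dim = dim // n_heads

fused = nn.Linear(dim, n_heads * head_dim, bias=False)  # fused projection, like wq

# Per-head linears built by copying row slices of the fused weight,
# mirroring what AttentionSha.__init__ does.
per_head = nn.ModuleList(
    [nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads)]
)
for i in range(n_heads):
    per_head[i].weight.data.copy_(fused.weight[i * head_dim : (i + 1) * head_dim])

x = torch.randn(2, 8, dim)  # (batch, seq_len, dim)
fused_out = fused(x).view(2, 8, n_heads, head_dim)
for i in range(n_heads):
    # Head i of the fused projection matches the i-th single-head linear.
    assert torch.allclose(per_head[i](x), fused_out[..., i, :], atol=1e-5)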
