|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-unsafe |
| 8 | + |
| 9 | +# Example script for exporting Llama2 to flatbuffer |
| 10 | + |
| 11 | +import math |
| 12 | +from typing import List, Optional, Tuple |
| 13 | + |
| 14 | +import torch |
| 15 | +from executorch.examples.models.llama.llama_transformer import Attention |
| 16 | +from torch import nn |
| 17 | + |
| 18 | + |
| 19 | +def apply_rotary_emb_single( |
| 20 | + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor |
| 21 | +) -> torch.Tensor: |
| 22 | + x_r, x_i = x[..., ::2], x[..., 1::2] |
| 23 | + |
| 24 | + x_out_r = x_r * freqs_cos - x_i * freqs_sin |
| 25 | + x_out_i = x_r * freqs_sin + x_i * freqs_cos |
| 26 | + |
| 27 | + x_out = torch.cat([x_out_r, x_out_i], dim=-1) |
| 28 | + return x_out |
| 29 | + |
| 30 | + |
| 31 | +class KVCacheSHA(torch.nn.Module): |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + max_batch_size: int, |
| 35 | + max_seq_length: int, |
| 36 | + n_heads: int, |
| 37 | + head_dim: int, |
| 38 | + dtype=torch.float32, |
| 39 | + ): |
| 40 | + super().__init__() |
| 41 | + |
| 42 | + # a buffer per head |
| 43 | + cache_shape = (max_batch_size, max_seq_length, head_dim) |
| 44 | + for i in range(n_heads): |
| 45 | + self.register_buffer( |
| 46 | + f"past_k_caches_{i}", |
| 47 | + torch.zeros(cache_shape, dtype=dtype, device="cpu"), |
| 48 | + persistent=False, |
| 49 | + ) |
| 50 | + self.register_buffer( |
| 51 | + f"past_v_caches_{i}", |
| 52 | + torch.zeros(cache_shape, dtype=dtype, device="cpu"), |
| 53 | + persistent=False, |
| 54 | + ) |
| 55 | + |
| 56 | + def update( |
| 57 | + self, |
| 58 | + input_pos: torch.Tensor, |
| 59 | + k_val: torch.Tensor, |
| 60 | + v_val: torch.Tensor, |
| 61 | + cache_idx: int, |
| 62 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 63 | + new_k = torch.ops.aten.index_put_( |
| 64 | + getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val |
| 65 | + ) |
| 66 | + new_v = torch.ops.aten.index_put_( |
| 67 | + getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val |
| 68 | + ) |
| 69 | + return new_k, new_v |
| 70 | + |
| 71 | + def get_cache(self, head_idx): |
| 72 | + return getattr(self, f"past_k_caches_{head_idx}"), getattr( |
| 73 | + self, f"past_v_caches_{head_idx}" |
| 74 | + ) |
| 75 | + |
| 76 | + |
| 77 | +class SDPASHA(torch.nn.Module): |
| 78 | + |
| 79 | + def __init__( |
| 80 | + self, |
| 81 | + max_batch_size: int, |
| 82 | + max_seq_length: int, |
| 83 | + n_heads: int, |
| 84 | + n_rep: int, |
| 85 | + head_dim: int, |
| 86 | + dim: int, |
| 87 | + ): |
| 88 | + super().__init__() |
| 89 | + self.head_dim = head_dim |
| 90 | + self.n_rep = n_rep |
| 91 | + self.dim = dim |
| 92 | + self.kv_cache = KVCacheSHA( |
| 93 | + max_batch_size, max_seq_length, n_heads // n_rep, head_dim |
| 94 | + ) |
| 95 | + self.scale_factor = math.sqrt(head_dim) |
| 96 | + |
| 97 | + def forward( |
| 98 | + self, |
| 99 | + input_pos: torch.Tensor, |
| 100 | + qs: List[torch.Tensor], |
| 101 | + ks: List[torch.Tensor], |
| 102 | + vs: List[torch.Tensor], |
| 103 | + mask, |
| 104 | + ): |
| 105 | + |
| 106 | + transpose_ks = [] |
| 107 | + for i in range(len(ks)): |
| 108 | + new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i) |
| 109 | + transpose_ks.append(new_k.transpose(-2, -1).contiguous()) |
| 110 | + |
| 111 | + output = [] |
| 112 | + for i, q in enumerate(qs): |
| 113 | + cache_idx = i // self.n_rep |
| 114 | + _, v = self.kv_cache.get_cache(cache_idx) |
| 115 | + |
| 116 | + attn_mask = mask[input_pos] |
| 117 | + |
| 118 | + attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor |
| 119 | + attn_weight += attn_mask |
| 120 | + attn_weight = torch.softmax(attn_weight, dim=-1) |
| 121 | + output.append(attn_weight @ v.contiguous()) |
| 122 | + |
| 123 | + return torch.cat(output, dim=-1) |
| 124 | + |
| 125 | + |
| 126 | +class AttentionSHA(nn.Module): |
| 127 | + def __init__(self, attention_mha: nn.Module): |
| 128 | + super().__init__() |
| 129 | + if not attention_mha.use_kv_cache: |
| 130 | + raise NotImplementedError("bert mode is not support") |
| 131 | + |
| 132 | + self.n_heads = attention_mha.n_heads |
| 133 | + self.n_kv_heads = attention_mha.n_kv_heads |
| 134 | + self.n_rep = self.n_heads // self.n_kv_heads |
| 135 | + self.dim = attention_mha.dim |
| 136 | + self.max_batch_size = attention_mha.max_batch_size |
| 137 | + self.max_seq_len = attention_mha.max_seq_len |
| 138 | + self.head_dim = attention_mha.dim // self.n_heads |
| 139 | + self.SDPA = SDPASHA( |
| 140 | + self.max_batch_size, |
| 141 | + self.max_seq_len, |
| 142 | + self.n_heads, |
| 143 | + self.n_rep, |
| 144 | + self.head_dim, |
| 145 | + self.dim, |
| 146 | + ) |
| 147 | + self.wq = nn.ModuleList( |
| 148 | + [ |
| 149 | + nn.Linear(self.dim, self.head_dim, bias=False) |
| 150 | + for _ in range(self.n_heads) |
| 151 | + ] |
| 152 | + ) |
| 153 | + self.wk = nn.ModuleList( |
| 154 | + [ |
| 155 | + nn.Linear(self.dim, self.head_dim, bias=False) |
| 156 | + for _ in range(self.n_kv_heads) |
| 157 | + ] |
| 158 | + ) |
| 159 | + self.wv = nn.ModuleList( |
| 160 | + [ |
| 161 | + nn.Linear(self.dim, self.head_dim, bias=False) |
| 162 | + for _ in range(self.n_kv_heads) |
| 163 | + ] |
| 164 | + ) |
| 165 | + |
| 166 | + for i in range(self.n_heads): |
| 167 | + self.wq[i].weight.data.copy_( |
| 168 | + attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] |
| 169 | + ) |
| 170 | + for i in range(self.n_kv_heads): |
| 171 | + self.wk[i].weight.data.copy_( |
| 172 | + attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] |
| 173 | + ) |
| 174 | + self.wv[i].weight.data.copy_( |
| 175 | + attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] |
| 176 | + ) |
| 177 | + self.wo = attention_mha.wo |
| 178 | + |
| 179 | + causal_mask = torch.tril( |
| 180 | + torch.ones( |
| 181 | + self.max_seq_len, |
| 182 | + self.max_seq_len, |
| 183 | + dtype=torch.bool, |
| 184 | + device="cpu", |
| 185 | + ) |
| 186 | + ) |
| 187 | + self.register_buffer("mask", causal_mask, persistent=False) |
| 188 | + |
| 189 | + def forward( |
| 190 | + self, |
| 191 | + x: torch.Tensor, |
| 192 | + freqs_cos: torch.Tensor, |
| 193 | + freqs_sin: torch.Tensor, |
| 194 | + input_pos: Optional[torch.Tensor] = None, |
| 195 | + ): |
| 196 | + # QKV |
| 197 | + q = [wq(x) for wq in self.wq] |
| 198 | + k = [wk(x) for wk in self.wk] |
| 199 | + v = [wv(x) for wv in self.wv] |
| 200 | + for i in range(len(q)): |
| 201 | + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) |
| 202 | + for i in range(len(k)): |
| 203 | + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin) |
| 204 | + |
| 205 | + output = self.SDPA(input_pos, q, k, v, self.mask) |
| 206 | + return self.wo(output) |
| 207 | + |
| 208 | + |
| 209 | +def replace_attention_to_attention_sha(module: torch.nn.Module): |
| 210 | + for name, child in module.named_children(): |
| 211 | + if isinstance(child, Attention): |
| 212 | + setattr( |
| 213 | + module, |
| 214 | + name, |
| 215 | + AttentionSHA(child), |
| 216 | + ) |
| 217 | + else: |
| 218 | + replace_attention_to_attention_sha(child) |
| 219 | + return module |
0 commit comments