@@ -7,13 +7,22 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+import types
+from typing import Optional
+
 import torch
 
-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
 )
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 class RopeWithAttentionSink(Rope):
@@ -167,3 +176,106 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
             )
             self.position_shift -= num_to_evict  # pyre-ignore [8]
         return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # We need view_copy elimination
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Prepare for space in KV cache and get position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    shifted_position = input_pos + position_shift
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, shifted_position)
+
+    output = self.SDPA(shifted_position, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_kv_cache(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, KVCache)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+            n_heads=child.n_heads,
+            head_dim=child.head_dim,
+            transpose_cache=child.transpose_cache,
+            enable_dynamic_shape=child.enable_dynamic_shape,
+            rope=rope_with_attention_sink,
+            max_batch_size=child.max_batch_size,
+            window_size=window_size,
+            sink_size=sink_size,
+            eviction_batch_size=eviction_batch_size,
+            dtype=child.k_cache.dtype,
+        )
+        return kv_cache_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention_forward(module: torch.nn.Module):
+    for name, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention_forward(child_module)  # pyre-ignore [6]
+
+        if isinstance(child_module, Attention):
+            module._modules[name].forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, module._modules[name]
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int = 4,
+    window_size: int = 2044,
+    eviction_batch_size: int = 1,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to run inference with Attention Sink.
+    There are mainly three steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace KVCache with KVCacheWithAttentionSink
+    - Replace Attention's forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(params=params)
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_kv_cache(
+        module, rope_with_attention_sink, sink_size, window_size, eviction_batch_size
+    )
+    _replace_attention_forward(module)
+    return module
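
For reference, a minimal usage sketch of the new enable_attention_sink entry point (not part of the diff). It assumes this file is importable as executorch.examples.models.llama.source_transformation.attention_sink, that a llama Transformer can be constructed directly from the same ModelArgs, and that use_kv_cache=True is set so attention_sink_forward can run; the vocab_size value and the model(tokens, input_pos) call are illustrative assumptions.

# Hypothetical usage sketch; construction details below are assumptions,
# not part of this diff.
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)

# Illustrative config; a real config also sets dim, n_layers, n_heads, etc.
params = ModelArgs(use_kv_cache=True, vocab_size=32000)
model = Transformer(params)

# Keep the first 4 tokens as attention sinks plus a sliding window of 2044
# recent tokens, so the KV cache holds at most 4 + 2044 = 2048 entries.
model = enable_attention_sink(
    module=model,
    params=params,
    sink_size=4,
    window_size=2044,
    eviction_batch_size=1,
)

# Decode token by token; once the cache fills up, evict_tokens() shifts the
# cached positions so generation can continue past the cache capacity.
tokens = torch.tensor([[1]], dtype=torch.long)
input_pos = torch.tensor([0], dtype=torch.long)
with torch.no_grad():
    logits = model(tokens, input_pos)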
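_replace_attention_forward swaps behavior without replacing the module: it rebinds a free function as a per-instance method via types.MethodType, so the existing Attention weights, KV cache, and any applied quantization stay in place and only the forward computation changes. A standalone sketch of that rebinding pattern, using a toy module in place of the real Attention class:

import types

import torch


class ToyAttention(torch.nn.Module):
    """Stand-in for Attention, used only to illustrate the patching pattern."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # `self` is the specific ToyAttention instance this function was bound to,
    # so a real replacement can still reach its submodules and buffers.
    return x * 2


toy = ToyAttention()
# Instance-level override: only this object is patched, not the class,
# mirroring how attention_sink_forward is attached per Attention instance.
toy.forward = types.MethodType(patched_forward, toy)
print(toy.forward(torch.ones(2)))  # tensor([2., 2.])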