# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

+import types
from typing import Optional

import torch

-from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import (
+    Attention,
+    KVCache,
+    ModelArgs,
+    Rope,
+)
from executorch.examples.models.llama.rope import (
    apply_rotary_emb_to_k,
    hf_apply_rotary_emb_to_k,
)
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class RopeWithAttentionSink(Rope):
@@ -206,3 +213,112 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
            )
            self.position_shift -= num_to_evict  # pyre-ignore [8]
        return self.position_shift
+
+
+def attention_sink_forward(
+    self,
+    x: torch.Tensor,
+    freqs_cos: torch.Tensor,
+    freqs_sin: torch.Tensor,
+    input_pos: Optional[torch.Tensor] = None,
+):
+    assert self.use_kv_cache
+    assert input_pos is not None
+
+    bsz, seqlen, _ = x.shape
+
+    # QKV
+    q, k, v = self.wq(x), self.wk(x), self.wv(x)
+    # We need view_copy elimination
+    q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+    k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+    v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+    # Make room in the KV cache for the new tokens and get the position shift
+    position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)
+
+    # RoPE relative positional embeddings with shifted position in KV cache
+    q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+
+    output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
+    return self.wo(output)
+
+
+def _replace_rope(
+    module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, Rope)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        return rope_with_attention_sink
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def _replace_attention(
+    module: torch.nn.Module,
+    rope_with_attention_sink: RopeWithAttentionSink,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+):
+    for _, child_module in module._modules.items():
+        if len(list(child_module.children())) > 0:  # pyre-ignore [16]
+            _replace_attention(
+                module=child_module,  # pyre-ignore [6]
+                rope_with_attention_sink=rope_with_attention_sink,
+                sink_size=sink_size,
+                window_size=window_size,
+                eviction_batch_size=eviction_batch_size,
+            )
+
+        if isinstance(child_module, Attention):
+            kv_cache = child_module.kv_cache
+            kv_cache_with_attention_sink = KVCacheWithAttentionSink(
+                n_heads=kv_cache.n_heads,
+                head_dim=kv_cache.head_dim,
+                transpose_cache=kv_cache.transpose_cache,
+                enable_dynamic_shape=kv_cache.enable_dynamic_shape,
+                rope=rope_with_attention_sink,
+                max_batch_size=kv_cache.max_batch_size,
+                window_size=window_size,
+                sink_size=sink_size,
+                eviction_batch_size=eviction_batch_size,
+                dtype=kv_cache.k_cache.dtype,
+            )
+            child_module.kv_cache = kv_cache_with_attention_sink
+            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
+            child_module.forward = types.MethodType(  # pyre-ignore
+                attention_sink_forward, child_module
+            )
+
+
+def enable_attention_sink(
+    module: torch.nn.Module,
+    params: ModelArgs,
+    sink_size: int,
+    window_size: int,
+    eviction_batch_size: int,
+) -> torch.nn.Module:
+    """
+    Transform the model so it can run inference with Attention Sink.
+    There are three main steps:
+    - Replace Rope with RopeWithAttentionSink
+    - Replace each Attention's KVCache with KVCacheWithAttentionSink
+    - Replace each Attention's forward with attention_sink_forward
+    """
+    rope_with_attention_sink = RopeWithAttentionSink(
+        params=params,
+        window_size=window_size,
+        sink_size=sink_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    _replace_rope(module, rope_with_attention_sink)
+    _replace_attention(
+        module=module,
+        rope_with_attention_sink=rope_with_attention_sink,
+        sink_size=sink_size,
+        window_size=window_size,
+        eviction_batch_size=eviction_batch_size,
+    )
+    return module
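
For context, a minimal usage sketch of how this transformation might be wired into an eager model before export. Only enable_attention_sink and ModelArgs come from the code above; build_llama_model is a hypothetical placeholder for however the eager Llama module is constructed, and the sink/window/eviction values are illustrative, not prescribed by this change.

# Illustrative sketch, not part of this diff.
params = ModelArgs(use_kv_cache=True)  # KV cache must be enabled; attention_sink_forward asserts it
model = build_llama_model(params)      # hypothetical helper returning an eager torch.nn.Module

# Example policy: keep the first 4 tokens as attention sinks, keep a 2044-token
# sliding window, and evict cached tokens in batches of at least 1.
model = enable_attention_sink(
    module=model,
    params=params,
    sink_size=4,
    window_size=2044,
    eviction_batch_size=1,
)
# The returned module now uses RopeWithAttentionSink and KVCacheWithAttentionSink,
# so decoding can continue past the original cache length.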