@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from collections import deque
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -47,29 +48,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for the next iteration. The runtime
         needs to implement the same operation.
         """
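+        # Illustration (shift_pointer, transpose=False): with a (1, 4, d) cache
+        # holding rows [c0, c1, c2, c3] and a 2-row slice [u0, u1] taken from
+        # `update` at update_pos, the roll gives [c2, c3, c0, c1] and the
+        # assignment leaves [c2, c3, u0, u1]; smart_mask instead writes the
+        # same rows in place at [pos, pos + update_len).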
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         return updated
 
@@ -163,6 +174,176 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len
 
 
+class StaticAttentionIOManager:
+    class NGramCache:
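+        # A bounded FIFO of unique entries: `add` is a no-op for duplicates,
+        # and once max_size entries exist the oldest is evicted, e.g. with
+        # max_size=2, add((1, 2)); add((3, 4)); add((5, 6)) keeps [(3, 4), (5, 6)].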
+        def __init__(self, max_size):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, x):
+            if x in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(x)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
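+        # One (1, cache_len, head_dim) buffer per layer/head pair, keyed via
+        # StaticKVCache.calculate_cache_key (e.g. "l0,h0").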
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
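+        # Feed the prompt in fixed-size chunks of input_len tokens; each call
+        # to _run_once folds the chunk's K/V entries into the static caches.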
+        for i in range(0, len(tokens), self.input_len):
+            x = tokens[i : i + self.input_len]
+            last_len = len(x)
+            y = self._run_once(model, x)[0]
+
+        if y.dim() == 2:
+            return y
+
+        return y[:, last_len - 1 : last_len, :]
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
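+        # Greedy decoding: run one token at a time (right-padded to input_len
+        # inside _run_once) and append the argmax of the first output position.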
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
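+        # The exported graph has a fixed input shape, so shorter token lists
+        # are right-padded with id 0; only non_padded_len positions are later
+        # committed to the caches.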
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
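+        # Unmask the freshly written positions, fold the per-head K/V updates
+        # returned by the model into the static caches, and advance the write
+        # position.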
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
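
Usage note (editor's sketch, not part of the commit): assuming a `model` callable
that follows the `(tokens, attn_inputs) -> (logits, attn_updates)` convention
expected by `_run_once`, a populated `ModelArgs` instance as `config`, and
`prompt_tokens` as a List[int] of prompt ids, the new manager drives prefill and
greedy decode roughly like this; the sizes and stop id below are hypothetical.

    mgr = StaticAttentionIOManager(config, input_len=32, cache_len=1024)
    logits = mgr.prefill(model, prompt_tokens)  # prompt consumed in input_len chunks
    first = logits.argmax(-1).item()            # greedy pick from the last position
    out = mgr.decode(model, first, n=128, stop_tokens=[2])
    mgr.reset()                                 # reuse the same buffers for a new sequence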