@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -47,29 +47,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"

     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]

         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]

         return updated

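For anyone mirroring this on the runtime side, here is a small self-contained sketch (toy shapes and values, not taken from this change) of what the non-transposed update now does for both cache styles: only the `update_len` rows starting at `update_pos` inside the padded update tensor are written into the cache.

```python
import torch

# Toy sizes: a cache of 4 positions; the model emitted 3 padded rows of which
# only the first 2 are real tokens (update_pos=0, update_len=2).
cache = torch.zeros(1, 4, 2)  # (batch, cache_len, head_dim)
update = torch.arange(6.0).reshape(1, 3, 2)
pos, update_pos, update_len = 0, 0, 2

# "shift_pointer": roll the cache toward the front, write the new rows at the end.
shifted = torch.roll(cache, -update_len, -2)
shifted[:, -update_len:, :] = update[:, update_pos : update_pos + update_len, :]

# "smart_mask": leave the cache in place, write at the current position `pos`.
masked = torch.clone(cache)
masked[:, pos : pos + update_len, :] = update[:, update_pos : update_pos + update_len, :]
```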
@@ -163,6 +173,164 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len


+class StaticAttentionIOManager:
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.config = config
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ) -> torch.Tensor:
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        logits = None
+        all_logits = None
+        for i in range(0, len(tokens), self.input_len):
+            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+            if self.config.generate_full_logits:
+                if all_logits is None:
+                    all_logits = logits
+                else:
+                    all_logits = torch.cat([all_logits, logits], dim=1)
+
+        if self.config.generate_full_logits:
+            return all_logits[:, : len(tokens), :]
+
+        return logits
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
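Roughly how the new `StaticAttentionIOManager` is meant to be driven end to end; a minimal sketch, assuming `config` is the usual `ModelArgs`, `model` is a callable with the `(tokens, attention_io_dict)` signature expected by `_run_once`, and `prompt_tokens` / `eos_token_id` are placeholders not defined in this change. It also assumes `config.generate_full_logits` is set so `prefill` returns one logit row per prompt token.

```python
# Minimal usage sketch; `config`, `model`, `prompt_tokens`, and `eos_token_id`
# are placeholders and not part of this change.
mgr = StaticAttentionIOManager(config, input_len=32, cache_len=512)

# Prefill runs the prompt through the model in fixed chunks of `input_len`
# tokens, filling the static KV caches and unmasking as it goes.
logits = mgr.prefill(model, prompt_tokens)

# With generate_full_logits=True, the last row corresponds to the final prompt
# token; greedily pick the first generated token from it.
first_token = logits[:, -1, :].argmax().item()

# Decode up to 128 tokens one at a time, stopping early on EOS.
generated = mgr.decode(model, first_token, 128, stop_tokens=[eos_token_id])

# reset() rewinds the cache position and attention mask before a new sequence.
mgr.reset()
```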