
Commit 812c645

sxu authored and facebook-github-bot committed
Static attention Python I/O manager (#11763)
Summary:
Pull Request resolved: #11763

Add a helper class to simplify cache and mask management. Useful for modeling work such as quantization and evaluation, and it will also be used to implement lookahead decoding.

Differential Revision: D76844656
1 parent 7565342 commit 812c645
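
A minimal usage sketch of the new helper follows. The config values, the build_static_transformer helper, and the choice of input_len/cache_len are illustrative assumptions; only the StaticAttentionIOManager API itself comes from this commit.

# Hypothetical driver code; see the diff below for the real class definition.
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

config = ModelArgs(generate_full_logits=True)  # other fields left at defaults (assumption)
model = build_static_transformer(config)  # hypothetical helper: any model matching the static attention I/O contract

mgr = StaticAttentionIOManager(config, input_len=32, cache_len=config.max_seq_len - 32)

# prefill handles chunking, padding, RoPE frequency slicing, and cache/mask updates.
prompt = [1, 15, 27, 42]
logits = mgr.prefill(model, prompt)

# With generate_full_logits=True the last prompt position is at index len(prompt) - 1.
last_token = logits[:, len(prompt) - 1, :].argmax().item()

# Greedy decode until a stop token or 64 new tokens, then reset for the next sequence.
new_tokens = mgr.decode(model, last_token, 64, stop_tokens=[2])
mgr.reset()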

File tree

2 files changed: +203 -56 lines changed


examples/models/llama/static_attention.py

Lines changed: 197 additions & 10 deletions
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from collections import defaultdict, deque
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn

@@ -47,29 +48,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         return updated
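
For reference, a tiny self-contained sketch of the new update_pos/update_len arguments added above: they select a window out of a larger update tensor instead of always writing the whole thing (shapes and values here are made up).

import torch

from executorch.examples.models.llama.static_attention import StaticKVCache

cache = torch.zeros(1, 8, 4)                  # (batch, cache_len, head_dim)
update = torch.arange(16.0).reshape(1, 4, 4)  # a 4-position update produced by the model

# Write only positions 2-3 of the update into cache positions 0-1.
new_cache = StaticKVCache.apply_update(
    cache, update, pos=0, style="smart_mask", update_pos=2, update_len=2
)
assert torch.equal(new_cache[:, :2, :], update[:, 2:4, :])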
@@ -163,6 +174,182 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len
 
 
+class StaticAttentionIOManager:
+    class NGramCache:
+        def __init__(self, max_size):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, x):
+            if x in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(x)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.config = config
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ) -> torch.Tensor:
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        logits = None
+        all_logits = None
+        for i in range(0, len(tokens), self.input_len):
+            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+            if self.config.generate_full_logits:
+                if all_logits is None:
+                    all_logits = logits
+                else:
+                    all_logits = torch.cat([all_logits, logits], dim=1)
+
+        if self.config.generate_full_logits:
+            return all_logits[:, :len(tokens), :]
+
+        return logits
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
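
The nested NGramCache introduced above is a small bounded, de-duplicating FIFO, presumably for the lookahead decoding mentioned in the summary. A quick illustrative sketch (values made up):

from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

ngrams = StaticAttentionIOManager.NGramCache(max_size=2)
ngrams.add((1, 2, 3))
ngrams.add((1, 2, 3))  # duplicate, ignored
ngrams.add((4, 5, 6))
ngrams.add((7, 8, 9))  # cache is full, so the oldest entry is evicted
print(list(ngrams))    # [(4, 5, 6), (7, 8, 9)]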

examples/models/llama/tests/test_static_attention.py

Lines changed: 6 additions & 46 deletions
@@ -1,12 +1,13 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
+from executorch.examples.models.llama.attention import AttentionMHA
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionIOManager,
     StaticAttentionMask,
     StaticKVCache,
 )

@@ -171,62 +172,21 @@ def test_within_transformer(self):
         static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
 
         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
-        rope = Rope(config)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
         expected = mha_transformer(x)
 
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
-            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
-            mask.tensor[:, :, cache_len:] = torch.triu(
-                torch.full((1, chunk_len, chunk_len), float("-inf")),
-                diagonal=1,
-            )
-            k_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
-            v_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
+            mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
-                y_i, attn_update = static_transformer(
-                    x[:, i * chunk_len : (i + 1) * chunk_len],
-                    attn_options=ForwardOptions(
-                        mask=mask.tensor,
-                        freqs_cos_override=freqs_cos[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        freqs_sin_override=freqs_sin[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        in_cache_state=(k_caches, v_caches),
-                        out_cache_state=({}, {}),
-                    ),
+                y_i = mgr.prefill(
+                    static_transformer,
+                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
                 )
                 ys.append(y_i)
-                mask.unmask(chunk_len)
-                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-                if i < n_chunks - 1:
-                    for cache_id, update in k_cache_updates.items():
-                        k_caches[cache_id] = StaticKVCache.apply_update(
-                            k_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
-                    for cache_id, update in v_cache_updates.items():
-                        v_caches[cache_id] = StaticKVCache.apply_update(
-                            v_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
 
             self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())