Commit 2fd273d

sxu authored and facebook-github-bot committed
Static attention Python I/O manager (#11763)
Summary: Pull Request resolved: #11763. Add a helper class to simplify cache and mask management. Useful for modeling during quantization and evaluation; it will also be used to implement lookahead decoding. Differential Revision: D76844656
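The class added below wraps mask setup, RoPE frequency slicing, and KV cache updates behind prefill and decode calls. A minimal usage sketch, assuming an already-constructed static-attention transformer `model`, its `ModelArgs` `config`, a tokenized prompt `prompt_tokens`, and an EOS token id `eos_id` (all placeholders, not part of this diff; input_len/cache_len values are illustrative):

    from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

    mgr = StaticAttentionIOManager(config, input_len=32, cache_len=512)

    # Chunked prefill over the prompt; returns logits for the last prompt position.
    logits = mgr.prefill(model, prompt_tokens)

    # Greedy decode, stopping at EOS or after 128 tokens.
    # Assumes the usual (1, 1, vocab) logits shape returned by prefill.
    init_token = logits.argmax().item()
    generated = mgr.decode(model, init_token, 128, stop_tokens=[eos_id])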
1 parent 830631d commit 2fd273d

File tree

2 files changed: +197, -56 lines

examples/models/llama/static_attention.py

Lines changed: 191 additions & 10 deletions
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from collections import deque
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -47,29 +48,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"

     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]

         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]

         return updated
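For reference, a small sketch (not part of this diff) of what the new update_pos/update_len arguments buy: only a slice of a possibly padded update tensor is written into the cache, so callers can reuse one fixed-size output buffer when fewer than input_len positions are valid. Assumes StaticKVCache is imported from executorch.examples.models.llama.static_attention.

    import torch

    cache = torch.zeros(1, 8, 4)   # (batch, cache_len, head_dim)
    update = torch.ones(1, 3, 4)   # padded update buffer; only the first 2 rows are real

    new_cache = StaticKVCache.apply_update(
        cache, update, pos=5, style="smart_mask", update_pos=0, update_len=2
    )
    # Rows 5 and 6 of new_cache now hold update[:, 0:2, :]; all other rows stay zero.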

@@ -163,6 +174,176 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len


+class StaticAttentionIOManager:
+    class NGramCache:
+        def __init__(self, max_size):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, x):
+            if x in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(x)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        for i in range(0, len(tokens), self.input_len):
+            x = tokens[i : i + self.input_len]
+            last_len = len(x)
+            y = self._run_once(model, tokens[i : i + self.input_len])[0]
+
+        if y.dim() == 2:
+            return y
+
+        return y[:, last_len - 1 : last_len, :]
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
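The nested NGramCache added above is not exercised elsewhere in this diff; per the summary it is groundwork for lookahead decoding. A quick sketch (illustrative only) of its behavior as a deduplicating FIFO bounded at max_size:

    cache = StaticAttentionIOManager.NGramCache(max_size=2)
    cache.add((1, 2, 3))
    cache.add((1, 2, 3))  # duplicate, ignored
    cache.add((4, 5, 6))
    cache.add((7, 8, 9))  # at capacity, so the oldest entry (1, 2, 3) is evicted
    print(list(cache))    # [(4, 5, 6), (7, 8, 9)]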

examples/models/llama/tests/test_static_attention.py

Lines changed: 6 additions & 46 deletions
@@ -1,12 +1,13 @@
 import unittest

 import torch
-from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
+from executorch.examples.models.llama.attention import AttentionMHA
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionIOManager,
     StaticAttentionMask,
     StaticKVCache,
 )

@@ -171,62 +172,21 @@ def test_within_transformer(self):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)

         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
-        rope = Rope(config)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
         expected = mha_transformer(x)

         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len

         def test_with_style(style):
-            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
-            mask.tensor[:, :, cache_len:] = torch.triu(
-                torch.full((1, chunk_len, chunk_len), float("-inf")),
-                diagonal=1,
-            )
-            k_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
-            v_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
+            mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
-                y_i, attn_update = static_transformer(
-                    x[:, i * chunk_len : (i + 1) * chunk_len],
-                    attn_options=ForwardOptions(
-                        mask=mask.tensor,
-                        freqs_cos_override=freqs_cos[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        freqs_sin_override=freqs_sin[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        in_cache_state=(k_caches, v_caches),
-                        out_cache_state=({}, {}),
-                    ),
+                y_i = mgr.prefill(
+                    static_transformer,
+                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
                 )
                 ys.append(y_i)
-                mask.unmask(chunk_len)
-                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-                if i < n_chunks - 1:
-                    for cache_id, update in k_cache_updates.items():
-                        k_caches[cache_id] = StaticKVCache.apply_update(
-                            k_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
-                    for cache_id, update in v_cache_updates.items():
-                        v_caches[cache_id] = StaticKVCache.apply_update(
-                            v_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )

             self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
