
Commit 3953a8f

sxu authored and facebook-github-bot committed
Static attention Python I/O manager (#11763)
Summary: Pull Request resolved: #11763. Add a helper class to simplify cache and mask management. Useful for modeling work on quantization and evaluation, and it will also be used to implement lookahead decoding. Differential Revision: D76844656
1 parent 6b47a16 commit 3953a8f
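
As orientation before the diff, here is a minimal usage sketch of the new StaticAttentionIOManager. It is not part of this commit; `config`, `static_transformer`, `prompt_tokens`, `eos_id`, and the concrete lengths are illustrative placeholders, assumed to follow the I/O convention shown in _run_once below.

# Minimal usage sketch. Assumes config.generate_full_logits=True so the last
# prompt position of the prefill logits predicts the next token; all names
# other than StaticAttentionIOManager are placeholders.
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

mgr = StaticAttentionIOManager(config, input_len=32, cache_len=480, style="shift_pointer")

# Prefill the prompt in input_len-sized chunks (padded internally), then pick
# the next token greedily from the last prompt position.
prefill_logits = mgr.prefill(static_transformer, prompt_tokens)
first_token = prefill_logits[:, -1, :].argmax().item()

# Greedily decode up to 128 more tokens, stopping early on EOS.
generated = mgr.decode(static_transformer, first_token, 128, stop_tokens=[eos_id])

mgr.reset()  # clear position, mask, and cache-full state before the next sequence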

2 files changed: +184 −56 lines

examples/models/llama/static_attention.py

Lines changed: 178 additions & 10 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -47,29 +47,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         return updated
 
@@ -163,6 +173,164 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len
 
 
+class StaticAttentionIOManager:
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.config = config
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ) -> torch.Tensor:
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        logits = None
+        all_logits = None
+        for i in range(0, len(tokens), self.input_len):
+            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+            if self.config.generate_full_logits:
+                if all_logits is None:
+                    all_logits = logits
+                else:
+                    all_logits = torch.cat([all_logits, logits], dim=1)
+
+        if self.config.generate_full_logits:
+            return all_logits[:, : len(tokens), :]
+
+        return logits
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
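
The update_pos/update_len parameters added to apply_update above let callers write just a window of the update tensor into the cache, which is how _update_states handles padded inputs. A small illustrative sketch with assumed shapes (not from this diff):

import torch

from executorch.examples.models.llama.static_attention import StaticKVCache

# Assumed shapes: cache_len=6, input_len=4, head_dim=8; only the first
# update_len=2 positions of `update` carry real (non-padded) data.
cache = torch.zeros(1, 6, 8)
update = torch.randn(1, 4, 8)

new_cache = StaticKVCache.apply_update(
    cache, update, pos=0, style="smart_mask", update_pos=0, update_len=2
)

# Only update[:, 0:2, :] is copied, into cache positions 0..1; the padded tail
# of `update` and the remaining cache positions are untouched.
assert torch.equal(new_cache[:, :2, :], update[:, :2, :])
assert torch.equal(new_cache[:, 2:, :], torch.zeros(1, 4, 8))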

examples/models/llama/tests/test_static_attention.py

Lines changed: 6 additions & 46 deletions
@@ -1,12 +1,13 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
+from executorch.examples.models.llama.attention import AttentionMHA
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionIOManager,
     StaticAttentionMask,
     StaticKVCache,
 )
@@ -171,62 +172,21 @@ def test_within_transformer(self):
         static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
 
         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
-        rope = Rope(config)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
         expected = mha_transformer(x)
 
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
-            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
-            mask.tensor[:, :, cache_len:] = torch.triu(
-                torch.full((1, chunk_len, chunk_len), float("-inf")),
-                diagonal=1,
-            )
-            k_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
-            v_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
+            mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
-                y_i, attn_update = static_transformer(
-                    x[:, i * chunk_len : (i + 1) * chunk_len],
-                    attn_options=ForwardOptions(
-                        mask=mask.tensor,
-                        freqs_cos_override=freqs_cos[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        freqs_sin_override=freqs_sin[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        in_cache_state=(k_caches, v_caches),
-                        out_cache_state=({}, {}),
-                    ),
+                y_i = mgr.prefill(
+                    static_transformer,
+                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
                 )
                 ys.append(y_i)
-                mask.unmask(chunk_len)
-                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-                if i < n_chunks - 1:
-                    for cache_id, update in k_cache_updates.items():
-                        k_caches[cache_id] = StaticKVCache.apply_update(
-                            k_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
-                    for cache_id, update in v_cache_updates.items():
-                        v_caches[cache_id] = StaticKVCache.apply_update(
-                            v_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
 
             self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
