Static attention Python I/O manager #11763

Merged: 1 commit, Jun 19, 2025
examples/models/llama/static_attention.py (188 changes: 178 additions, 10 deletions)
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -47,29 +47,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, pos, style, transpose=False):
+    def apply_update(
+        cache, update, pos, style, transpose=False, update_pos=0, update_len=None
+    ):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
         if style == "shift_pointer":
             if transpose:
-                update_len = update.size(-1)
+                update_len = update_len or update.size(-1)
                 updated = torch.roll(cache, -update_len, -1)
-                updated[:, :, -update_len:] = update
+                updated[:, :, -update_len:] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
+                update_len = update_len or update.size(-2)
                 updated = torch.roll(cache, -update_len, -2)
-                updated[:, -update_len:, :] = update
+                updated[:, -update_len:, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         if style == "smart_mask":
             updated = torch.clone(cache)
             if transpose:
-                update_len = update.size(-1)
-                updated[:, :, pos : pos + update_len] = update
+                update_len = update_len or update.size(-1)
+                updated[:, :, pos : pos + update_len] = update[
+                    :, :, update_pos : update_pos + update_len
+                ]
             else:
-                update_len = update.size(-2)
-                updated[:, pos : pos + update_len, :] = update
+                update_len = update_len or update.size(-2)
+                updated[:, pos : pos + update_len, :] = update[
+                    :, update_pos : update_pos + update_len, :
+                ]
 
         return updated
 
@@ -163,6 +173,164 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len
 
 
+class StaticAttentionIOManager:
+    def __init__(
+        self,
+        config: ModelArgs,
+        input_len: int,
+        cache_len: int,
+        style: str = "shift_pointer",
+        mask_val: float = float("-inf"),
+    ):
+        self.mask = StaticAttentionMask(
+            input_len, cache_len, style=style, mask_val=mask_val
+        )
+
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_seq_len)
+        self.freqs_cos = freqs[0]
+        self.freqs_sin = freqs[1]
+
+        self.k_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+        self.v_caches = {
+            StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
+                1, cache_len, config.head_dim
+            )
+            for layer_id in range(config.n_layers)
+            for head_id in range(config.n_kv_heads)
+        }
+
+        self.config = config
+        self.input_len = input_len
+        self.cache_len = cache_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def reset(self):
+        self.pos = 0
+        self.cache_full = False
+        self.mask.reset()
+
+    def prefill(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+    ) -> torch.Tensor:
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        logits = None
+        all_logits = None
+        for i in range(0, len(tokens), self.input_len):
+            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+            if self.config.generate_full_logits:
+                if all_logits is None:
+                    all_logits = logits
+                else:
+                    all_logits = torch.cat([all_logits, logits], dim=1)
+
+        if self.config.generate_full_logits:
+            return all_logits[:, : len(tokens), :]
+
+        return logits
+
+    def decode(
+        self,
+        model: Callable[..., Any],
+        init_token: int,
+        n: int,
+        stop_tokens: Optional[List[int]] = None,
+    ):
+        if self.cache_full:
+            raise RuntimeError("KV cache is full.")
+
+        self.mask.tensor[:, :, self.cache_len :] = torch.triu(
+            torch.full((1, self.input_len, self.input_len), self.mask_val),
+            diagonal=1,
+        )
+
+        stop_tokens = stop_tokens or []
+        new_tokens = [init_token]
+        for _ in range(n):
+            y = self._run_once(model, new_tokens[-1:])[0]
+            new_tokens.append(y[:, :1, :].argmax().item())
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        return new_tokens
+
+    def _run_once(
+        self,
+        model: Callable[..., Any],
+        tokens: List[int],
+        non_padded_len: Optional[int] = None,
+        freqs_cos_override: Optional[torch.Tensor] = None,
+        freqs_sin_override: Optional[torch.Tensor] = None,
+    ):
+        n_tokens = len(tokens)
+        if n_tokens < self.input_len:
+            tokens += [0] * (self.input_len - n_tokens)
+        tokens = torch.tensor([tokens], dtype=torch.int32)
+        if freqs_cos_override is None:
+            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
+        if freqs_sin_override is None:
+            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        y, attn_updates = model(
+            tokens,
+            {
+                "mask": self.mask.tensor,
+                "freqs_cos_override": freqs_cos_override,
+                "freqs_sin_override": freqs_sin_override,
+                "in_cache_state": (self.k_caches, self.v_caches),
+            },
+        )
+        non_padded_len = non_padded_len or n_tokens
+        if self.pos + non_padded_len <= self.cache_len:
+            self._update_states(attn_updates, 0, non_padded_len)
+        else:
+            self.cache_full = True
+
+        return y, attn_updates
+
+    def _update_states(self, attn_updates, update_pos, update_len):
+        assert self.pos + update_len <= self.cache_len
+
+        self.mask.unmask(update_len)
+        k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]
+        for cache_id, update in k_cache_updates.items():
+            self.k_caches[cache_id] = StaticKVCache.apply_update(
+                self.k_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        for cache_id, update in v_cache_updates.items():
+            self.v_caches[cache_id] = StaticKVCache.apply_update(
+                self.v_caches[cache_id],
+                update,
+                self.pos,
+                style=self.style,
+                update_pos=update_pos,
+                update_len=update_len,
+            )
+        self.pos += update_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
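To make the `apply_update` change above concrete, here is a small sketch (not part of this PR) of how the two cache-update styles behave and what the new `update_pos`/`update_len` arguments do; the shapes and values are illustrative only.

```python
# Illustrative sketch only; relies on the StaticKVCache.apply_update shown above.
import torch

from executorch.examples.models.llama.static_attention import StaticKVCache

cache = torch.zeros(1, 4, 2)  # (batch, cache_len, head_dim)
update = torch.arange(6, dtype=torch.float).reshape(1, 3, 2)  # 3 new positions

# "shift_pointer": roll the cache left by the update length and write at the end.
shifted = StaticKVCache.apply_update(cache, update, pos=0, style="shift_pointer")
# shifted[:, -3:, :] now equals update; surviving entries moved toward the front.

# "smart_mask": keep existing entries in place and write at the given position.
masked = StaticKVCache.apply_update(cache, update, pos=0, style="smart_mask")
# masked[:, 0:3, :] now equals update; the rest of the cache is untouched.

# With update_pos/update_len, only a slice of the update is written, e.g. when
# the last input chunk was padded and only two of its positions are valid.
partial = StaticKVCache.apply_update(
    cache, update, pos=0, style="smart_mask", update_pos=1, update_len=2
)
# partial[:, 0:2, :] now equals update[:, 1:3, :]
```

The manager added above takes exactly this path through `_update_states`, passing `update_pos=0` and the non-padded length of the current chunk.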
examples/models/llama/tests/test_static_attention.py (52 changes: 6 additions, 46 deletions)
@@ -1,12 +1,13 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
+from executorch.examples.models.llama.attention import AttentionMHA
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionIOManager,
     StaticAttentionMask,
     StaticKVCache,
 )
@@ -171,62 +172,21 @@ def test_within_transformer(self):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
 
         x = torch.randint(config.vocab_size, (1, config.max_seq_len))
-        rope = Rope(config)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
         expected = mha_transformer(x)
 
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
 
         def test_with_style(style):
-            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
-            mask.tensor[:, :, cache_len:] = torch.triu(
-                torch.full((1, chunk_len, chunk_len), float("-inf")),
-                diagonal=1,
-            )
-            k_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
-            v_caches = {
-                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                    1, cache_len, config.head_dim
-                )
-                for layer_id in range(config.n_layers)
-                for i in range(config.n_kv_heads)
-            }
+            mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style)
             ys = []
             for i in range(n_chunks):
-                y_i, attn_update = static_transformer(
-                    x[:, i * chunk_len : (i + 1) * chunk_len],
-                    attn_options=ForwardOptions(
-                        mask=mask.tensor,
-                        freqs_cos_override=freqs_cos[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        freqs_sin_override=freqs_sin[
-                            i * chunk_len : (i + 1) * chunk_len
-                        ],
-                        in_cache_state=(k_caches, v_caches),
-                        out_cache_state=({}, {}),
-                    ),
+                y_i = mgr.prefill(
+                    static_transformer,
+                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
                 )
                 ys.append(y_i)
-                mask.unmask(chunk_len)
-                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-                if i < n_chunks - 1:
-                    for cache_id, update in k_cache_updates.items():
-                        k_caches[cache_id] = StaticKVCache.apply_update(
-                            k_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
-                    for cache_id, update in v_cache_updates.items():
-                        v_caches[cache_id] = StaticKVCache.apply_update(
-                            v_caches[cache_id], update, pos=chunk_len * i, style=style
-                        )
 
             self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
 
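For reviewers trying the new class out, here is a minimal usage sketch of `StaticAttentionIOManager` (hypothetical, not part of this PR). It assumes `config` is a `ModelArgs` with `generate_full_logits` enabled and `static_transformer` is a model built with `StaticAttention`, as in the updated test above; `prompt_tokens` and `eos_id` are placeholder values.

```python
# Hypothetical usage sketch; `config`, `static_transformer`, `prompt_tokens`, and
# `eos_id` are assumed to exist and are not defined in this PR.
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

input_len = 32  # tokens fed per forward pass; shorter chunks are zero-padded
cache_len = 512  # static KV cache length

mgr = StaticAttentionIOManager(config, input_len, cache_len, style="smart_mask")

# Prefill the prompt in input_len-sized chunks; the manager updates the KV caches
# and the attention mask after every chunk.
logits = mgr.prefill(static_transformer, prompt_tokens)

# Greedy-decode up to 64 new tokens, one forward pass per token, stopping at eos_id.
init_token = logits[:, len(prompt_tokens) - 1, :].argmax().item()
generated = mgr.decode(static_transformer, init_token, n=64, stop_tokens=[eos_id])

# Reset the cache position, mask, and cache-full flag before the next prompt.
mgr.reset()
```

The mask the manager owns spans `cache_len + input_len` columns: the cache columns are unmasked as positions are written, and the right `input_len` block is rewritten as a causal mask at the start of `prefill()` and `decode()`.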