
Add smart mask style KVCache and mask #8463

Merged (1 commit) on Feb 20, 2025

examples/models/llama/static_attention.py (67 changes: 57 additions & 10 deletions)
@@ -47,19 +47,29 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, transpose=False):
+    def apply_update(cache, update, pos, style, transpose=False):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
-        if transpose:
-            update_len = update.size(-1)
-            updated = torch.roll(cache, -update_len, -1)
-            updated[:, :, -update_len:] = update
-        else:
-            update_len = update.size(-2)
-            updated = torch.roll(cache, -update_len, -2)
-            updated[:, -update_len:, :] = update
+        if style == "shift_pointer":
+            if transpose:
+                update_len = update.size(-1)
+                updated = torch.roll(cache, -update_len, -1)
+                updated[:, :, -update_len:] = update
+            else:
+                update_len = update.size(-2)
+                updated = torch.roll(cache, -update_len, -2)
+                updated[:, -update_len:, :] = update
+
+        if style == "smart_mask":
+            updated = torch.clone(cache)
+            if transpose:
+                update_len = update.size(-1)
+                updated[:, :, pos : pos + update_len] = update
+            else:
+                update_len = update.size(-2)
+                updated[:, pos : pos + update_len, :] = update
 
         return updated

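For reference, here is a minimal sketch (not part of the diff) of what the two update styles do to a toy 1 x 4 x 1 cache; the shapes and values are made up for illustration, and only StaticKVCache.apply_update is taken from this change.

import torch

from executorch.examples.models.llama.static_attention import StaticKVCache

new_k = torch.tensor([[[3.0]]])  # 1 x update_len x head_dim, one new position

# "shift_pointer": valid entries sit at the end of the buffer; every update
# rolls the whole cache toward the front and writes the new rows into the tail.
cache = torch.tensor([[[0.0], [0.0], [1.0], [2.0]]])
out = StaticKVCache.apply_update(cache, new_k, pos=2, style="shift_pointer")
# out == [[[0.], [1.], [2.], [3.]]]  (pos is not used by this style)

# "smart_mask": valid entries sit at the front of the buffer at fixed
# positions; the new rows are written in place at pos and nothing else moves.
cache = torch.tensor([[[1.0], [2.0], [0.0], [0.0]]])
out = StaticKVCache.apply_update(cache, new_k, pos=2, style="smart_mask")
# out == [[[1.], [2.], [3.], [0.]]]

In other words, smart_mask only writes the update_len new rows instead of rolling the whole buffer, and relies on the attention mask (StaticAttentionMask below) to control which cache positions are visible.
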
@@ -114,6 +124,44 @@ def update(
         return all_data, (out_k_cache, out_v_cache)
 
 
+class StaticAttentionMask:
+    def __init__(self, input_len, cache_len, style):
+        self.input_len = input_len
+        self.cache_len = cache_len
+        assert style in ("shift_pointer", "smart_mask")
+        self.style = style
+        self.unmasked_len = 0
+        self.tensor = torch.zeros(1, input_len, input_len + cache_len)
+        self.reset()
+
+    def reset(self):
+        self.unmasked_len = 0
+        self.tensor[:, :, : self.cache_len] = float("-inf")
+
+    def unmask(self, new_unmasked_len):
+        if new_unmasked_len <= 0:
+            return
+
+        if self.style == "shift_pointer":
+            self.tensor[
+                :,
+                :,
+                self.cache_len
+                - self.unmasked_len
+                - new_unmasked_len : self.cache_len
+                - self.unmasked_len,
+            ] = 0
+
+        if self.style == "smart_mask":
+            self.tensor[
+                :,
+                :,
+                self.unmasked_len : self.unmasked_len + new_unmasked_len,
+            ] = 0
+
+        self.unmasked_len += new_unmasked_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
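
A similarly hedged sketch of how the new StaticAttentionMask is meant to be driven; the sizes are arbitrary, and the causal triangle over the input columns is added by the caller, exactly as the updated tests below do.

from executorch.examples.models.llama.static_attention import StaticAttentionMask

mask = StaticAttentionMask(input_len=2, cache_len=4, style="smart_mask")
# After construction (reset() runs in __init__): the 4 cache columns are -inf,
# i.e. nothing is cached yet, and the trailing input_len columns are 0; the
# causal triangle over those last columns is the caller's responsibility.

mask.unmask(2)  # two cache positions now hold real K/V entries
# With style="smart_mask" this zeroes cache columns [0, 2), matching rows
# written at pos=0; with style="shift_pointer" the same call would zero
# columns [2, 4), matching rows rolled into the tail of the cache.
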
@@ -135,7 +183,6 @@ def forward(
         x_r, x_i = x[..., ::2], x[..., 1::2]
         x_out_r = x_r * freqs_cos - x_i * freqs_sin
         x_out_i = x_r * freqs_sin + x_i * freqs_cos
-
         x_out = torch.cat([x_out_r, x_out_i], dim=-1)
         return x_out

examples/models/llama/tests/test_static_attention.py (178 changes: 97 additions & 81 deletions)
@@ -7,6 +7,7 @@
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionMask,
     StaticKVCache,
 )

@@ -92,48 +93,54 @@ def test_with_cache(self):
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
-        mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
-        mask[:, :, :cache_len] = float("-inf")
-        mask[:, :, cache_len:] = torch.triu(
-            torch.full((1, chunk_len, chunk_len), float("-inf")),
-            diagonal=1,
-        )
-        k_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for i in range(config.n_kv_heads)
-        }
-        v_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for i in range(config.n_kv_heads)
-        }
-        ys = []
-        for i in range(n_chunks):
-            y_i, attn_update = static_attn(
-                x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                freqs_sin[i * chunk_len : (i + 1) * chunk_len],
-                mask=mask,
-                in_cache_state=(k_caches, v_caches),
-                out_cache_state=({}, {}),
-            )
-            ys.append(y_i)
-            mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0
-            k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-            for cache_id, update in k_cache_updates.items():
-                k_caches[cache_id] = StaticKVCache.apply_update(
-                    k_caches[cache_id], update
-                )
-            for cache_id, update in v_cache_updates.items():
-                v_caches[cache_id] = StaticKVCache.apply_update(
-                    v_caches[cache_id], update
-                )
-
-        y = torch.cat(ys, dim=1)
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+
+        def test_with_style(style):
+            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
+            mask.tensor[:, :, cache_len:] = torch.triu(
+                torch.full((1, chunk_len, chunk_len), float("-inf")),
+                diagonal=1,
+            )
+            k_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for i in range(config.n_kv_heads)
+            }
+            v_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for i in range(config.n_kv_heads)
+            }
+            ys = []
+            for i in range(n_chunks):
+                y_i, attn_update = static_attn(
+                    x[:, i * chunk_len : (i + 1) * chunk_len, :],
+                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    mask=mask.tensor,
+                    in_cache_state=(k_caches, v_caches),
+                    out_cache_state=({}, {}),
+                )
+                ys.append(y_i)
+                mask.unmask(chunk_len)
+                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
+
+                if i < n_chunks - 1:
+                    for cache_id, update in k_cache_updates.items():
+                        k_caches[cache_id] = StaticKVCache.apply_update(
+                            k_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+                    for cache_id, update in v_cache_updates.items():
+                        v_caches[cache_id] = StaticKVCache.apply_update(
+                            v_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+
+            y = torch.cat(ys, dim=1)
+            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+
+        test_with_style("shift_pointer")
+        test_with_style("smart_mask")
 
     def test_within_transformer(self):
         config = ModelArgs(
@@ -162,48 +169,57 @@ def test_within_transformer(self):
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
-        mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
-        mask[:, :, :cache_len] = float("-inf")
-        mask[:, :, cache_len:] = torch.triu(
-            torch.full((1, chunk_len, chunk_len), float("-inf")),
-            diagonal=1,
-        )
-        k_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for layer_id in range(config.n_layers)
-            for i in range(config.n_kv_heads)
-        }
-        v_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for layer_id in range(config.n_layers)
-            for i in range(config.n_kv_heads)
-        }
-        ys = []
-        for i in range(n_chunks):
-            y_i, attn_update = static_transformer(
-                x[:, i * chunk_len : (i + 1) * chunk_len],
-                attn_options=ForwardOptions(
-                    mask=mask,
-                    freqs_cos_override=freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin_override=freqs_sin[i * chunk_len : (i + 1) * chunk_len],
-                    in_cache_state=(k_caches, v_caches),
-                    out_cache_state=({}, {}),
-                ),
-            )
-            ys.append(y_i)
-            mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0
-            k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-            for cache_id, update in k_cache_updates.items():
-                k_caches[cache_id] = StaticKVCache.apply_update(
-                    k_caches[cache_id], update
-                )
-            for cache_id, update in v_cache_updates.items():
-                v_caches[cache_id] = StaticKVCache.apply_update(
-                    v_caches[cache_id], update
-                )
-
-        self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+
+        def test_with_style(style):
+            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
+            mask.tensor[:, :, cache_len:] = torch.triu(
+                torch.full((1, chunk_len, chunk_len), float("-inf")),
+                diagonal=1,
+            )
+            k_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for layer_id in range(config.n_layers)
+                for i in range(config.n_kv_heads)
+            }
+            v_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
+                )
+                for layer_id in range(config.n_layers)
+                for i in range(config.n_kv_heads)
+            }
+            ys = []
+            for i in range(n_chunks):
+                y_i, attn_update = static_transformer(
+                    x[:, i * chunk_len : (i + 1) * chunk_len],
+                    attn_options=ForwardOptions(
+                        mask=mask.tensor,
+                        freqs_cos_override=freqs_cos[
+                            i * chunk_len : (i + 1) * chunk_len
+                        ],
+                        freqs_sin_override=freqs_sin[
+                            i * chunk_len : (i + 1) * chunk_len
+                        ],
+                        in_cache_state=(k_caches, v_caches),
+                        out_cache_state=({}, {}),
+                    ),
+                )
+                ys.append(y_i)
+                mask.unmask(chunk_len)
+                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
+                if i < n_chunks - 1:
+                    for cache_id, update in k_cache_updates.items():
+                        k_caches[cache_id] = StaticKVCache.apply_update(
+                            k_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+                    for cache_id, update in v_cache_updates.items():
+                        v_caches[cache_id] = StaticKVCache.apply_update(
+                            v_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+
+            self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+
+        test_with_style("shift_pointer")
+        test_with_style("smart_mask")
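
Condensing the two tests above, the per-chunk flow a runtime would mirror looks roughly like the sketch below; run_chunk is a hypothetical stand-in for invoking the exported model, and k_caches / v_caches are assumed to be pre-filled with zero tensors keyed by StaticKVCache.calculate_cache_key, as in the tests.

import torch

from executorch.examples.models.llama.static_attention import (
    StaticAttentionMask,
    StaticKVCache,
)


def prefill_in_chunks(run_chunk, chunks, chunk_len, cache_len, k_caches, v_caches, style):
    # One mask per generation; it is mutated in place as the cache fills up.
    mask = StaticAttentionMask(chunk_len, cache_len, style=style)
    mask.tensor[:, :, cache_len:] = torch.triu(
        torch.full((1, chunk_len, chunk_len), float("-inf")), diagonal=1
    )
    outputs = []
    for i, chunk in enumerate(chunks):
        # Hypothetical call: returns the chunk output plus per-head K/V updates.
        y, k_updates, v_updates = run_chunk(chunk, mask.tensor, k_caches, v_caches)
        outputs.append(y)
        mask.unmask(chunk_len)  # the positions just written become visible
        for key, update in k_updates.items():
            k_caches[key] = StaticKVCache.apply_update(
                k_caches[key], update, pos=chunk_len * i, style=style
            )
        for key, update in v_updates.items():
            v_caches[key] = StaticKVCache.apply_update(
                v_caches[key], update, pos=chunk_len * i, style=style
            )
    return torch.cat(outputs, dim=1)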