
Commit 463119e

Add smart mask style KVCache and mask
Differential Revision: D69595959
Pull Request resolved: #8463
1 parent 139be81 commit 463119e

2 files changed: +154 −91 lines

examples/models/llama/static_attention.py

Lines changed: 57 additions & 10 deletions
@@ -47,19 +47,29 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str:
         return f"l{layer_id},h{head_id}"
 
     @staticmethod
-    def apply_update(cache, update, transpose=False):
+    def apply_update(cache, update, pos, style, transpose=False):
         """
         After inference, update the cache state for next iteration. The runtime needs to
         implement the same operation.
         """
-        if transpose:
-            update_len = update.size(-1)
-            updated = torch.roll(cache, -update_len, -1)
-            updated[:, :, -update_len:] = update
-        else:
-            update_len = update.size(-2)
-            updated = torch.roll(cache, -update_len, -2)
-            updated[:, -update_len:, :] = update
+        if style == "shift_pointer":
+            if transpose:
+                update_len = update.size(-1)
+                updated = torch.roll(cache, -update_len, -1)
+                updated[:, :, -update_len:] = update
+            else:
+                update_len = update.size(-2)
+                updated = torch.roll(cache, -update_len, -2)
+                updated[:, -update_len:, :] = update
+
+        if style == "smart_mask":
+            updated = torch.clone(cache)
+            if transpose:
+                update_len = update.size(-1)
+                updated[:, :, pos : pos + update_len] = update
+            else:
+                update_len = update.size(-2)
+                updated[:, pos : pos + update_len, :] = update
 
         return updated
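Note on the new style argument (not part of the diff): a minimal sketch of what the two update styles do to a toy cache of shape (batch, cache_len, head_dim). The tensors and sizes below are made up for illustration; the operations mirror the hunk above.

import torch

cache = torch.zeros(1, 4, 2)   # (batch=1, cache_len=4, head_dim=2), toy sizes
update = torch.ones(1, 1, 2)   # projection of one new token
pos = 0                        # tokens already written (used only by smart_mask)

# shift_pointer: roll the whole cache toward the front and overwrite the tail.
shifted = torch.roll(cache, -1, -2)
shifted[:, -1:, :] = update

# smart_mask: leave existing entries where they are and write in place at `pos`.
in_place = torch.clone(cache)
in_place[:, pos : pos + 1, :] = update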

@@ -114,6 +124,44 @@ def update(
         return all_data, (out_k_cache, out_v_cache)
 
 
+class StaticAttentionMask:
+    def __init__(self, input_len, cache_len, style):
+        self.input_len = input_len
+        self.cache_len = cache_len
+        assert style in ("shift_pointer", "smart_mask")
+        self.style = style
+        self.unmasked_len = 0
+        self.tensor = torch.zeros(1, input_len, input_len + cache_len)
+        self.reset()
+
+    def reset(self):
+        self.unmasked_len = 0
+        self.tensor[:, :, : self.cache_len] = float("-inf")
+
+    def unmask(self, new_unmasked_len):
+        if new_unmasked_len <= 0:
+            return
+
+        if self.style == "shift_pointer":
+            self.tensor[
+                :,
+                :,
+                self.cache_len
+                - self.unmasked_len
+                - new_unmasked_len : self.cache_len
+                - self.unmasked_len,
+            ] = 0
+
+        if self.style == "smart_mask":
+            self.tensor[
+                :,
+                :,
+                self.unmasked_len : self.unmasked_len + new_unmasked_len,
+            ] = 0
+
+        self.unmasked_len += new_unmasked_len
+
+
 class _Rope(nn.Module):
     def __init__(self, use_hf_rope):
         super().__init__()
@@ -135,7 +183,6 @@ def forward(
         x_r, x_i = x[..., ::2], x[..., 1::2]
         x_out_r = x_r * freqs_cos - x_i * freqs_sin
         x_out_i = x_r * freqs_sin + x_i * freqs_cos
-
         x_out = torch.cat([x_out_r, x_out_i], dim=-1)
         return x_out
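Note on the new StaticAttentionMask (not part of the diff): a minimal sketch of how unmask() opens up the cache region of the mask for each style. The sizes (input_len=2, cache_len=4) are made up for illustration; the import path is the one used by the tests below.

import torch
from executorch.examples.models.llama.static_attention import StaticAttentionMask

for style in ("shift_pointer", "smart_mask"):
    mask = StaticAttentionMask(2, 4, style)
    # reset() fills the cache region [:, :, :4] with -inf: no history is visible yet.
    mask.unmask(2)  # two cache slots now hold valid history
    # shift_pointer zeroes slots 2..3 (right end of the cache region), matching a
    # cache whose newest entries sit at the end after torch.roll.
    # smart_mask zeroes slots 0..1 (left end), matching in-place writes at `pos`.
    print(style, mask.tensor[0, 0, :4])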

examples/models/llama/tests/test_static_attention.py

Lines changed: 97 additions & 81 deletions
@@ -7,6 +7,7 @@
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
     StaticAttention,
+    StaticAttentionMask,
     StaticKVCache,
 )

@@ -92,48 +93,54 @@ def test_with_cache(self):
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
-        mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
-        mask[:, :, :cache_len] = float("-inf")
-        mask[:, :, cache_len:] = torch.triu(
-            torch.full((1, chunk_len, chunk_len), float("-inf")),
-            diagonal=1,
-        )
-        k_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for i in range(config.n_kv_heads)
-        }
-        v_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for i in range(config.n_kv_heads)
-        }
-        ys = []
-        for i in range(n_chunks):
-            y_i, attn_update = static_attn(
-                x[:, i * chunk_len : (i + 1) * chunk_len, :],
-                freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                freqs_sin[i * chunk_len : (i + 1) * chunk_len],
-                mask=mask,
-                in_cache_state=(k_caches, v_caches),
-                out_cache_state=({}, {}),
+
+        def test_with_style(style):
+            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
+            mask.tensor[:, :, cache_len:] = torch.triu(
+                torch.full((1, chunk_len, chunk_len), float("-inf")),
+                diagonal=1,
             )
-            ys.append(y_i)
-            mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0
-            k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-            for cache_id, update in k_cache_updates.items():
-                k_caches[cache_id] = StaticKVCache.apply_update(
-                    k_caches[cache_id], update
+            k_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
                 )
-            for cache_id, update in v_cache_updates.items():
-                v_caches[cache_id] = StaticKVCache.apply_update(
-                    v_caches[cache_id], update
+                for i in range(config.n_kv_heads)
+            }
+            v_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
                 )
-
-        y = torch.cat(ys, dim=1)
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+                for i in range(config.n_kv_heads)
+            }
+            ys = []
+            for i in range(n_chunks):
+                y_i, attn_update = static_attn(
+                    x[:, i * chunk_len : (i + 1) * chunk_len, :],
+                    freqs_cos[i * chunk_len : (i + 1) * chunk_len],
+                    freqs_sin[i * chunk_len : (i + 1) * chunk_len],
+                    mask=mask.tensor,
+                    in_cache_state=(k_caches, v_caches),
+                    out_cache_state=({}, {}),
+                )
+                ys.append(y_i)
+                mask.unmask(chunk_len)
+                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
+
+                if i < n_chunks - 1:
+                    for cache_id, update in k_cache_updates.items():
+                        k_caches[cache_id] = StaticKVCache.apply_update(
+                            k_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+                    for cache_id, update in v_cache_updates.items():
+                        v_caches[cache_id] = StaticKVCache.apply_update(
+                            v_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+
+            y = torch.cat(ys, dim=1)
+            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+
+        test_with_style("shift_pointer")
+        test_with_style("smart_mask")
 
     def test_within_transformer(self):
         config = ModelArgs(
@@ -162,48 +169,57 @@ def test_within_transformer(self):
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
-        mask = torch.zeros(1, chunk_len, cache_len + chunk_len)
-        mask[:, :, :cache_len] = float("-inf")
-        mask[:, :, cache_len:] = torch.triu(
-            torch.full((1, chunk_len, chunk_len), float("-inf")),
-            diagonal=1,
-        )
-        k_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for layer_id in range(config.n_layers)
-            for i in range(config.n_kv_heads)
-        }
-        v_caches = {
-            StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
-                1, cache_len, config.head_dim
-            )
-            for layer_id in range(config.n_layers)
-            for i in range(config.n_kv_heads)
-        }
-        ys = []
-        for i in range(n_chunks):
-            y_i, attn_update = static_transformer(
-                x[:, i * chunk_len : (i + 1) * chunk_len],
-                attn_options=ForwardOptions(
-                    mask=mask,
-                    freqs_cos_override=freqs_cos[i * chunk_len : (i + 1) * chunk_len],
-                    freqs_sin_override=freqs_sin[i * chunk_len : (i + 1) * chunk_len],
-                    in_cache_state=(k_caches, v_caches),
-                    out_cache_state=({}, {}),
-                ),
+
+        def test_with_style(style):
+            mask = StaticAttentionMask(chunk_len, cache_len, style=style)
+            mask.tensor[:, :, cache_len:] = torch.triu(
+                torch.full((1, chunk_len, chunk_len), float("-inf")),
+                diagonal=1,
             )
-            ys.append(y_i)
-            mask[:, :, cache_len - chunk_len * (i + 1) : cache_len] = 0
-            k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
-            for cache_id, update in k_cache_updates.items():
-                k_caches[cache_id] = StaticKVCache.apply_update(
-                    k_caches[cache_id], update
+            k_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
                 )
-            for cache_id, update in v_cache_updates.items():
-                v_caches[cache_id] = StaticKVCache.apply_update(
-                    v_caches[cache_id], update
+                for layer_id in range(config.n_layers)
+                for i in range(config.n_kv_heads)
+            }
+            v_caches = {
+                StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros(
+                    1, cache_len, config.head_dim
                 )
-
-        self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+                for layer_id in range(config.n_layers)
+                for i in range(config.n_kv_heads)
+            }
+            ys = []
+            for i in range(n_chunks):
+                y_i, attn_update = static_transformer(
+                    x[:, i * chunk_len : (i + 1) * chunk_len],
+                    attn_options=ForwardOptions(
+                        mask=mask.tensor,
+                        freqs_cos_override=freqs_cos[
+                            i * chunk_len : (i + 1) * chunk_len
+                        ],
+                        freqs_sin_override=freqs_sin[
+                            i * chunk_len : (i + 1) * chunk_len
+                        ],
+                        in_cache_state=(k_caches, v_caches),
+                        out_cache_state=({}, {}),
+                    ),
+                )
+                ys.append(y_i)
+                mask.unmask(chunk_len)
+                k_cache_updates, v_cache_updates = attn_update["out_cache_state"]
+                if i < n_chunks - 1:
+                    for cache_id, update in k_cache_updates.items():
+                        k_caches[cache_id] = StaticKVCache.apply_update(
+                            k_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+                    for cache_id, update in v_cache_updates.items():
+                        v_caches[cache_id] = StaticKVCache.apply_update(
+                            v_caches[cache_id], update, pos=chunk_len * i, style=style
+                        )
+
+            self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+
+        test_with_style("shift_pointer")
+        test_with_style("smart_mask")
