Commit 9d084c4

add attention_sink.py
add KVCacheWithAttentionSink

Pull Request resolved: #6579

This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window of `window_size` for new tokens.

Note: I am implementing and verifying `AttentionSink` in eager mode first, so the current implementation may still have some lowering errors. These will be resolved when we are ready to deploy `AttentionSink` to edge.

ghstack-source-id: 255715047
@exported-using-ghexport

Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/)

Co-authored-by: Lunwen He <[email protected]>
1 parent 2d499b3 commit 9d084c4
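
For orientation, here is a standalone sketch (mine, not part of this commit) of the policy that `KVCacheWithAttentionSink` implements in the diff below: the first `sink_size` positions are pinned as attention sinks, and once the cache of length `sink_size + window_size` would overflow, at least `eviction_batch_size` of the oldest non-sink tokens are evicted and the remaining recent tokens slide toward the sink. The `evict` helper and its list-based cache are illustrative stand-ins for the real tensor-based k/v caches.

# Illustrative only: plain-Python model of the sink + sliding-window policy
# used by KVCacheWithAttentionSink (no torch, no executorch imports).

def evict(cache, sink_size, window_size, eviction_batch_size, incoming):
    """Return the cache contents after making room for `incoming` new tokens."""
    max_len = sink_size + window_size
    if len(cache) + incoming <= max_len:
        return cache  # enough free slots, nothing to evict
    # Evict at least `eviction_batch_size` of the oldest non-sink tokens,
    # or more if that still would not free enough room.
    num_to_evict = max(len(cache) + incoming - max_len, eviction_batch_size)
    sinks = cache[:sink_size]
    survivors = cache[sink_size + num_to_evict:]  # recent tokens slide toward the sink
    return sinks + survivors

cache = list(range(12))   # token positions 0..11 fill a cache of size 4 + 8
cache = evict(cache, sink_size=4, window_size=8, eviction_batch_size=2, incoming=1)
print(cache)              # [0, 1, 2, 3, 6, 7, 8, 9, 10, 11]

In the real class the same arithmetic is applied to `k_cache` and `v_cache` via `narrow`/`cat`, and the surviving keys are additionally re-rotated with `rerotate_k` because their positions in the cache change.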

2 files changed, +586 −8 lines

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 120 additions & 1 deletion
@@ -11,7 +11,7 @@

 import torch

-from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
@@ -87,3 +87,122 @@ def rerotate_k(
         )

         return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)
+
+
+class KVCacheWithAttentionSink(KVCache):
+    """
+    KV cache that supports attention sink. It keeps the initial few tokens as attention sinks.
+    For other tokens, it uses a sliding window to keep the most recent tokens.
+
+    Parameters:
+        window_size: the size of the sliding window
+        sink_size: the number of initial tokens to keep as attention sinks
+        eviction_batch_size: the number of tokens to evict in a batch when there is not enough space in the KV cache
+    """
+
+    def __init__(
+        self,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        rope: RopeWithAttentionSink,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+        max_batch_size: int = 1,
+        dtype=torch.float32,
+    ):
+        super().__init__(
+            max_batch_size=max_batch_size,
+            max_seq_length=window_size + sink_size,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            transpose_cache=transpose_cache,
+            enable_dynamic_shape=enable_dynamic_shape,
+            dtype=dtype,
+        )
+        self.rope = rope
+        self.window_size = window_size
+        self.sink_size = sink_size
+        self.eviction_batch_size = eviction_batch_size
+        self.position_shift = 0
+
+    def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
+        """
+        Evict old tokens from the cache to make room for new tokens.
+
+        Parameters:
+            input_pos: the start position of the incoming tokens in the actual sequence
+            seq_len: the length of the incoming sequence
+
+        Returns:
+            the number of positions to shift for incoming tokens, i.e. the negative
+            of the total number of tokens evicted from the cache
+        """
+        input_pos_item = input_pos.item()
+        torch._check_is_size(input_pos_item)
+        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+            # There is not enough space in the cache to store the new tokens.
+            # We need to evict some old tokens and shift some recent tokens.
+            num_to_evict = max(
+                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                self.eviction_batch_size,
+            )
+            num_to_keep = (
+                input_pos_item + self.position_shift - self.sink_size - num_to_evict
+            )
+            num_empty_space = self.window_size - num_to_keep
+            dim_to_slice = 2 if self.transpose_cache else 1
+            k_to_keep = self.k_cache.narrow(
+                dim_to_slice,
+                self.sink_size + num_to_evict,  # pyre-ignore [6]
+                num_to_keep,  # pyre-ignore [6]
+            )
+            if self.transpose_cache:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep.transpose(1, 2),
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                ).transpose(1, 2)
+            else:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep,
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                )
+            self.k_cache = torch.cat(
+                [
+                    self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
+                    k_to_keep,
+                    torch.zeros_like(
+                        self.k_cache.narrow(
+                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
+                        )
+                    ),
+                ],
+                dim=dim_to_slice,
+            )
+            self.v_cache = torch.cat(
+                [
+                    self.v_cache.narrow(dim_to_slice, 0, self.sink_size),
+                    self.v_cache.narrow(
+                        dim_to_slice,
+                        self.sink_size + num_to_evict,  # pyre-ignore [6]
+                        num_to_keep,  # pyre-ignore [6]
+                    ),
+                    torch.zeros_like(
+                        self.v_cache.narrow(
+                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
+                        )
+                    ),
+                ],
+                dim=dim_to_slice,
+            )
+            self.position_shift -= num_to_evict  # pyre-ignore [8]
+        return self.position_shift
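
To make the index arithmetic in `evict_tokens` concrete, here is a worked example (my own numbers, not from the PR) with `sink_size = 4`, `window_size = 8` (so `max_seq_length = 12`), `eviction_batch_size = 2`, a full cache (`input_pos = 12`, `position_shift = 0`), and one incoming token (`seq_len = 1`):

# Worked example of the arithmetic in evict_tokens (illustrative values only).
sink_size, window_size, eviction_batch_size = 4, 8, 2
max_seq_length = sink_size + window_size            # 12, as set by __init__
input_pos, position_shift, seq_len = 12, 0, 1       # cache full, one new token

assert input_pos + position_shift + seq_len > max_seq_length  # eviction triggers

num_to_evict = max(
    input_pos + position_shift - max_seq_length + seq_len,    # 1 slot actually needed
    eviction_batch_size,                                       # but evict at least 2
)
num_to_keep = input_pos + position_shift - sink_size - num_to_evict   # 12 - 4 - 2 = 6
num_empty_space = window_size - num_to_keep                           # 8 - 6 = 2
position_shift -= num_to_evict                                        # -2

assert (num_to_evict, num_to_keep, num_empty_space, position_shift) == (2, 6, 2, -2)

After the call the cache holds the 4 sink tokens, the 6 most recent tokens re-rotated to positions 4..9, and 2 zeroed slots; per the docstring, the caller shifts the incoming token to position `input_pos + position_shift = 10`.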
