
Commit 079c3d2

[Executorch][llm] Add support for ring kv cache and ring attention
Introduced CachePositionsManager to keep track of the position held by each slot in the ring kv cache. These positions are used to generate the attention mask.

Differential Revision: [D73891427](https://our.internmc.facebook.com/intern/diff/D73891427/)
ghstack-source-id: 281102610
Pull Request resolved: #10608
1 parent 0165a02 commit 079c3d2
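The diff below only adds the position bookkeeping; the mask generation the commit message refers to is not part of this change. As a rough illustration of how the tracked cache_positions could drive a mask (build_ring_mask is a hypothetical helper, not code from this commit), one might write:

    import torch

    def build_ring_mask(
        cache_positions: torch.Tensor,  # [C], -1 marks an unfilled slot
        input_pos: torch.Tensor,        # [S], input_pos[0] is the first query position
        seq_len: int,
    ) -> torch.Tensor:
        # Query token i sits at absolute position start_pos + i. A cache slot is
        # attendable only if it holds a valid (non-sentinel) position that is not
        # in the future relative to that query token.
        start_pos = input_pos[0].item()
        query_positions = torch.arange(seq_len, dtype=torch.long) + start_pos    # [S]
        valid = cache_positions >= 0                                              # [C]
        causal = cache_positions.unsqueeze(0) <= query_positions.unsqueeze(1)     # [S, C]
        return valid.unsqueeze(0) & causal                                        # True where attention is allowed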

File tree: 3 files changed (+595 -0 lines)

examples/models/llama/attention.py

Lines changed: 107 additions & 0 deletions
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Dict, Optional, Tuple, Type, TypedDict

 import torch
@@ -160,6 +161,112 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+class CacheUpdateStrategy(Enum):
+    RING_BUFFER = "RingBuffer"
+    INVALID = "Invalid"
+
+
+class CachePositionsManager(nn.Module):
+    def __init__(
+        self,
+        max_context_length: int,
+        cache_update_strategy: CacheUpdateStrategy = CacheUpdateStrategy.RING_BUFFER,
+    ):
+        super().__init__()
+        assert (
+            cache_update_strategy == CacheUpdateStrategy.RING_BUFFER
+        ), "Only RingBuffer is supported"
+        self.max_context_length = max_context_length
+        self.register_buffer(
+            "cache_positions",
+            torch.zeros((self.max_context_length), dtype=torch.long, device="cpu"),
+        )
+
+    def calculate_positions_and_update_indices(self, input_pos: torch.Tensor, seq_len):
+        """
+        Calculate the indices into k_cache/v_cache where the k_val tensor should go.
+        Given input_pos and the length of k_val along the sequence dim, the
+        positions may have to wrap around if the update is smaller than the cache
+        capacity. If it is larger than the cache capacity, just pick the last
+        self.max_context_length entries.
+
+        Additionally, update the cache_positions buffer with the new indices.
+        Given the cache positions in the sequence dim, indicated by indices,
+        we can just update the cache_positions buffer using orig_indices.
+        For example, given a cache capacity of 4 and an update of length 3 with
+        start_pos = 2, we have
+        indices = [2, 3, 0]
+        orig_indices = [2, 3, 4]
+        so cache_positions after the update will be [4, 1, 2, 3].
+        Note that cache_positions[1] = 1 comes from a previous write to the cache.
+        The corner case is the state of cache_positions before the cache rolls over.
+        For example, when start_pos = 0 and the update is of length 2, we have
+        filled positions 0 and 1 in the buffer while the rest are invalid. In this
+        case we have
+        indices = [0, 1]
+        orig_indices = [0, 1]
+        But cache_positions = [0, 1, 0, 0] would not be valid, hence we have to
+        make sure that invalid positions hold a sentinel value of -1.
+        """
+        start_pos = input_pos[0].item()
+        torch._check_is_size(start_pos)
+        orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos
+        indices = orig_indices % self.max_context_length
+
+        full_t = torch.full((self.max_context_length,), -1, dtype=torch.long)
+        arange_tensor = torch.arange(self.max_context_length, dtype=torch.long)
+        cache_positions = torch.where(
+            arange_tensor < start_pos, self.cache_positions, full_t
+        )
+        self.cache_positions.copy_(cache_positions)
+        self.cache_positions.index_copy_(0, indices, orig_indices)
+
+        return indices
+
+
+class RingKVCache(KVCache):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_context_length: int,
+        n_heads: int,
+        head_dim: int,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
+    ):
+        super().__init__(
+            max_batch_size,
+            max_context_length,
+            n_heads,
+            head_dim,
+            enable_dynamic_shape,
+            dtype,
+        )
+        self.cache_positions_manager = CachePositionsManager(max_context_length)
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [S], k_val: [B, H, S, D]
+        seq_len = k_val.size(2)
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[0].item()
+            torch._check_is_size(start_pos)
+
+            self.k_cache.index_copy_(2, indices, k_val)
+            self.v_cache.index_copy_(2, indices, v_val)
+        else:
+            self.k_cache[:, :, indices] = k_val
+            self.v_cache[:, :, indices] = v_val
+
+        return self.k_cache, self.v_cache
+
+
 @register_attention("mha")
 class AttentionMHA(Attention):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
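For reference, the docstring's worked example can be reproduced with CachePositionsManager alone. A small eager-mode sketch (the import path is assumed from the file location):

    import torch
    from executorch.examples.models.llama.attention import CachePositionsManager

    mgr = CachePositionsManager(max_context_length=4)

    # 2-token prefill at start_pos = 0: slots 0 and 1 become valid, the rest stay -1.
    idx = mgr.calculate_positions_and_update_indices(torch.tensor([0]), seq_len=2)
    print(idx.tolist())                  # [0, 1]
    print(mgr.cache_positions.tolist())  # [0, 1, -1, -1]

    # 3-token update at start_pos = 2: absolute position 4 wraps around into slot 0.
    idx = mgr.calculate_positions_and_update_indices(torch.tensor([2]), seq_len=3)
    print(idx.tolist())                  # [2, 3, 0]
    print(mgr.cache_positions.tolist())  # [4, 1, 2, 3]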

examples/models/llama/tests/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -38,3 +38,14 @@ python_unittest(
         "//executorch/examples/models/llama:static_attention",
     ],
 )
+
+python_unittest(
+    name = "test_ring_kv_cache",
+    srcs = [
+        "test_ring_kv_cache.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+    ],
+)
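The new test_ring_kv_cache.py itself is not shown in this diff. Purely as an illustration of the kind of check it might contain, here is a hypothetical minimal test; it assumes the import path above and that the base KVCache allocates k_cache/v_cache as [B, n_heads, max_context_length, head_dim]:

    import unittest

    import torch
    from executorch.examples.models.llama.attention import RingKVCache


    class TestRingKVCache(unittest.TestCase):
        def test_update_wraps_past_capacity(self):
            cache = RingKVCache(
                max_batch_size=1,
                max_context_length=4,
                n_heads=1,
                head_dim=8,
                enable_dynamic_shape=True,
            )
            # Prefill two tokens at absolute positions 0-1.
            k = torch.ones(1, 1, 2, 8)
            cache.update(torch.tensor([0]), k, k)
            # Write three more tokens at positions 2-4; position 4 wraps into slot 0.
            k2 = torch.full((1, 1, 3, 8), 2.0)
            k_out, _ = cache.update(torch.tensor([2]), k2, k2)
            self.assertEqual(
                cache.cache_positions_manager.cache_positions.tolist(), [4, 1, 2, 3]
            )
            # Slot 0 now holds the wrapped-around token written at position 4.
            self.assertTrue(torch.equal(k_out[0, 0, 0], torch.full((8,), 2.0)))


    if __name__ == "__main__":
        unittest.main()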
