Commit 4304f5a

[Executorch][llm] Enable leveraging ring kv cache via module swap
This allows us to make some of the attention modules use a sliding-window KV cache, which will help enable models like gemma3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)
ghstack-source-id: 281455703
Pull Request resolved: #10611
1 parent: 2582abc
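For context, the "ring" here is a fixed-size buffer indexed modulo the window: each new token overwrites the oldest cache slot, so only the most recent window_size positions are retained. Below is a minimal sketch of that idea with toy sizes and names; the actual bookkeeping in this commit is done by CachePositionsManager and the ring cache classes in the diffs that follow.

import torch

# Toy illustration of a sliding-window ("ring") KV cache, not the ExecuTorch
# implementation: each position is written at slot `pos % window_size`,
# overwriting the oldest entry, so only the last `window_size` tokens survive.
window_size = 4
k_cache = torch.zeros(window_size, 2)  # [slot, head_dim], toy sizes
cache_positions = torch.full((window_size,), -1, dtype=torch.long)  # -1 = empty slot

for pos in range(6):  # feed 6 token positions into a 4-slot ring
    slot = pos % window_size
    k_cache[slot] = float(pos)  # stand-in for the k projection of this token
    cache_positions[slot] = pos

print(cache_positions)  # tensor([4, 5, 2, 3]): positions 0 and 1 were overwritten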

File tree

5 files changed: +522 −31 lines


examples/models/llama/attention.py

Lines changed: 13 additions & 5 deletions
@@ -150,6 +150,16 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+def _create_causal_mask_for_ring_buffer(
+    cache_positions, window_size, start_pos, seq_len
+):
+    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+    delta = pos_q - cache_positions
+    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
+    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+    return attn_mask
+
+
 class CacheUpdateStrategy(Enum):
     RING_BUFFER = "RingBuffer"
     INVALID = "Invalid"
@@ -283,12 +293,10 @@ def __init__(
         self.is_ring_buffer = True

     def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
-        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
         cache_positions = self.cache_positions_manager.cache_positions
-        delta = pos_q - cache_positions
-        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
-        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
-        return attn_mask
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
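To make the mask semantics concrete, here is a small standalone check of the helper factored out above (its body is copied from the diff; the surrounding setup is illustrative): with a window of 4 and a query at position 6, the slot still holding position 2 is masked out because it has fallen outside the window.

import torch

def _create_causal_mask_for_ring_buffer(cache_positions, window_size, start_pos, seq_len):
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
    return attn_mask

# A 4-slot ring after positions 0..5 were written: slots hold token positions
# [4, 5, 2, 3] (slots 0 and 1 were most recently overwritten by positions 4 and 5).
cache_positions = torch.tensor([4, 5, 2, 3])
mask = _create_causal_mask_for_ring_buffer(cache_positions, window_size=4, start_pos=6, seq_len=1)
print(mask)  # tensor([[0., 0., -inf, 0.]]): only positions 3, 4, 5 remain attendable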

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 190 additions & 1 deletion
@@ -10,7 +10,12 @@

 import torch
 import torch.nn as nn
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)

 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

@@ -75,6 +80,7 @@ def __init__(
         self.register_buffer(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
         )
+        self.cache_type = cache_type

     def _quantize(self, value):
         (
@@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
+
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)

@@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
         else:
             _replace_kv_cache_with_custom_kv_cache(child)
     return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
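The replacement pass itself is a plain isinstance dispatch over whatever cache module each attention layer already holds, driven by a per-layer list of window sizes where 0 means "leave the full-context cache alone". A minimal, self-contained mirror of that pattern is sketched below with toy stand-in classes (these are not the ExecuTorch KVCache/RingKVCache classes); it shows how a model can mix global and sliding-window layers.

import torch.nn as nn

# Toy stand-ins for illustration only; the real classes live in
# examples/models/llama (KVCache, RingKVCache, CustomKVCache, ...).
class ToyKVCache(nn.Module):
    def __init__(self, max_context_length):
        super().__init__()
        self.max_context_length = max_context_length

class ToyRingKVCache(ToyKVCache):
    def __init__(self, window_size):
        # The real ring caches allocate window_size * 2 slots; see the
        # "max_context_length * 2" comment in the diff above.
        super().__init__(window_size * 2)
        self.window_size = window_size

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.kv_cache = ToyKVCache(max_context_length=2048)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = ToyAttention()

class ToyModel(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(n_layers)])

def replace_with_ring(module, layer_sizes):
    # Mirrors replace_kv_cache_with_ring_kv_cache: a 0 entry keeps the
    # original (global) cache, a nonzero entry installs a sliding-window cache.
    assert len(layer_sizes) == len(module.layers)
    for block, window in zip(module.layers, layer_sizes):
        if window == 0:
            continue
        block.attention.kv_cache = ToyRingKVCache(window)
    return module

model = replace_with_ring(ToyModel(n_layers=4), layer_sizes=[512, 0, 512, 0])
print([type(b.attention.kv_cache).__name__ for b in model.layers])
# ['ToyRingKVCache', 'ToyKVCache', 'ToyRingKVCache', 'ToyKVCache']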

examples/models/llama/tests/TARGETS

Lines changed: 25 additions & 0 deletions
@@ -55,8 +55,33 @@ python_unittest(
     srcs = [
         "test_ring_attention.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
     deps = [
         "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
+    ],
+)
+
+python_unittest(
+    name = "test_replace_kv_cache",
+    srcs = [
+        "test_replace_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
     ],
 )
