 import torch

-from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
+from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
 from executorch.examples.models.llama.rope import (
     apply_rotary_emb_to_k,
     hf_apply_rotary_emb_to_k,
@@ -87,3 +87,122 @@ def rerotate_k(
         )

         return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)
+
+
+class KVCacheWithAttentionSink(KVCache):
+    """
+    KV cache that supports attention sink. It keeps the initial `sink_size` tokens
+    as the attention sink; for the remaining tokens it uses a sliding window that
+    keeps only the most recent ones.
+
+    Parameters:
+        window_size: the size of the sliding window
+        sink_size: the number of initial tokens to keep as the attention sink
+        eviction_batch_size: the number of tokens to evict in a batch when there is
+            not enough space in the KV cache
+    """
+
+    def __init__(
+        self,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        rope: RopeWithAttentionSink,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+        max_batch_size: int = 1,
+        dtype=torch.float32,
+    ):
+        super().__init__(
+            max_batch_size=max_batch_size,
+            max_seq_length=window_size + sink_size,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            transpose_cache=transpose_cache,
+            enable_dynamic_shape=enable_dynamic_shape,
+            dtype=dtype,
+        )
+        self.rope = rope
+        self.window_size = window_size
+        self.sink_size = sink_size
+        self.eviction_batch_size = eviction_batch_size
+        self.position_shift = 0
+
+    def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
+        """
+        Evict old tokens from the cache to make room for new tokens.
+
+        Parameters:
+            input_pos: the start position of the incoming tokens in the actual sequence
+            seq_len: the length of the incoming sequence
+
+        Returns:
+            the number of positions to shift for incoming tokens, i.e. the negative
+            of the total number of tokens evicted from the cache so far
+        """
+        input_pos_item = input_pos.item()
+        torch._check_is_size(input_pos_item)
+        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+            # There is not enough space in the cache to store the new tokens.
+            # Evict some old tokens and shift the most recent ones toward the sink.
+            num_to_evict = max(
+                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                self.eviction_batch_size,
+            )
+            num_to_keep = (
+                input_pos_item + self.position_shift - self.sink_size - num_to_evict
+            )
+            num_empty_space = self.window_size - num_to_keep
+            dim_to_slice = 2 if self.transpose_cache else 1
+            k_to_keep = self.k_cache.narrow(
+                dim_to_slice,
+                self.sink_size + num_to_evict,  # pyre-ignore [6]
+                num_to_keep,  # pyre-ignore [6]
+            )
+            if self.transpose_cache:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep.transpose(1, 2),
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                ).transpose(1, 2)
+            else:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep,
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                )
+            self.k_cache = torch.cat(
+                [
+                    self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
+                    k_to_keep,
+                    torch.zeros_like(
+                        self.k_cache.narrow(
+                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
+                        )
+                    ),
+                ],
+                dim=dim_to_slice,
+            )
+            self.v_cache = torch.cat(
+                [
+                    self.v_cache.narrow(dim_to_slice, 0, self.sink_size),
+                    self.v_cache.narrow(
+                        dim_to_slice,
+                        self.sink_size + num_to_evict,  # pyre-ignore [6]
+                        num_to_keep,  # pyre-ignore [6]
+                    ),
+                    torch.zeros_like(
+                        self.v_cache.narrow(
+                            dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
+                        )
+                    ),
+                ],
+                dim=dim_to_slice,
+            )
+            self.position_shift -= num_to_evict  # pyre-ignore [8]
+        return self.position_shift
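
To make the bookkeeping in `evict_tokens` above concrete, here is a small self-contained walkthrough of its arithmetic. The configuration numbers (`sink_size=4`, `window_size=124`, `eviction_batch_size=32`) are illustrative assumptions, not values taken from this change.

```python
# Standalone sketch of the eviction arithmetic in evict_tokens (illustrative values).
sink_size = 4
window_size = 124
max_seq_length = sink_size + window_size  # 128, as set in __init__
eviction_batch_size = 32

position_shift = 0
input_pos = 128  # next position in the full ("real") sequence
seq_len = 1      # decoding one token at a time

if input_pos + position_shift + seq_len > max_seq_length:
    # Evict at least enough to fit the new token, but never less than one batch.
    num_to_evict = max(
        input_pos + position_shift - max_seq_length + seq_len,  # 1
        eviction_batch_size,                                    # 32
    )
    # Recent tokens that survive, get re-rotated, and slide toward the sink.
    num_to_keep = input_pos + position_shift - sink_size - num_to_evict  # 92
    num_empty_space = window_size - num_to_keep                          # 32
    position_shift -= num_to_evict                                       # -32

print(num_to_evict, num_to_keep, num_empty_space, position_shift)  # 32 92 32 -32
```

After this step the cache holds the 4 sink tokens plus 92 re-rotated recent tokens, and the incoming token is written at cache position 128 + (-32) = 96, the first free slot.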
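Below is a rough sketch of how a decode loop is presumably meant to drive the new cache: evict first, then write at the shifted position. Only the `KVCacheWithAttentionSink` constructor and `evict_tokens` come from this diff; the `rope` construction is elided, and the `update()` call is assumed to be the write method inherited from the `KVCache` base class.

```python
# Sketch only; `rope`, `k`, `v`, and the inherited update() call are
# assumptions or placeholders, not part of this change.
rope = ...  # a RopeWithAttentionSink instance (construction not shown in this hunk)

kv_cache = KVCacheWithAttentionSink(
    n_heads=32,
    head_dim=128,
    transpose_cache=True,
    enable_dynamic_shape=False,
    rope=rope,
    window_size=124,
    sink_size=4,
    eviction_batch_size=32,
)

def decode_step(input_pos: torch.Tensor, seq_len: int, k: torch.Tensor, v: torch.Tensor):
    # Free space first; the returned non-positive shift maps positions in the
    # full sequence onto slots in the fixed-size cache.
    position_shift = kv_cache.evict_tokens(input_pos, seq_len)
    shifted_pos = input_pos + position_shift
    # Write the new keys/values at the shifted position (assumed base-class API).
    return kv_cache.update(shifted_pos, k, v)
```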