
Commit dfceca9

pytorchbot and helunwencser authored and committed
Implement get_freqs for RopeWithAttentionSink
This PR implements the `get_freqs` function for `RopeWithAttentionSink`. It returns the `freqs_cos` and `freqs_sin` for a given `input_pos` and `seq_len` after shifting tokens in the pre-computed `freqs_cos` and `freqs_sin`.

Differential Revision: [D66525306](https://our.internmc.facebook.com/intern/diff/D66525306/)

ghstack-source-id: 255582545
Pull Request resolved: #7100

Co-authored-by: Lunwen He <[email protected]>
1 parent cb7b345 commit dfceca9
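
The "shifting" described above is plain integer arithmetic over the pre-computed frequency tables. The sketch below is a standalone illustration of that rule, not code from this diff; the function name and the concrete numbers are chosen only for the example.

```python
# Standalone sketch of the position-shift rule described in the commit message.
# Not code from this diff; names and numbers are illustrative only.

def shifted_pos(
    input_pos: int,
    seq_len: int,
    position_shift: int,
    max_seq_length: int,       # window_size + sink_size
    eviction_batch_size: int,
) -> tuple[int, int]:
    """Return (effective position, updated position_shift) for one get_freqs call."""
    if input_pos + position_shift + seq_len > max_seq_length:
        # Cache is full: evict at least eviction_batch_size tokens,
        # or however many are needed to make room for the new tokens.
        num_to_evict = max(
            input_pos + position_shift - max_seq_length + seq_len,
            eviction_batch_size,
        )
        position_shift -= num_to_evict
    return input_pos + position_shift, position_shift

# Example: a 256-slot cache (window 252 + sink 4), 10 new tokens arriving at position 250.
print(shifted_pos(250, 10, 0, 256, 1))  # (246, -4): evict 4 tokens, read freqs at 246
```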

2 files changed (+77, −3)

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -7,6 +7,8 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

+from typing import Optional
+
 import torch

 from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
@@ -23,12 +25,37 @@ class RopeWithAttentionSink(Rope):
     in KVCache instead of positions in the actual text.
     """

-    def __init__(self, params: ModelArgs):
+    def __init__(
+        self,
+        params: ModelArgs,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+    ):
         super().__init__(params)
         if self.params.use_hf_rope:
             self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
+        self.max_seq_length = window_size + sink_size
+        assert self.max_seq_length == self.params.max_seq_len
+        self.eviction_batch_size = eviction_batch_size
+        self.position_shift = 0
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        assert input_pos is not None
+
+        input_pos_item = input_pos.item()
+        torch._check_is_size(input_pos_item)
+        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+            # There are not enough spaces in the cache to store the new tokens.
+            # We need to evict some old tokens and shift some recent tokens.
+            num_to_evict = max(
+                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                self.eviction_batch_size,
+            )
+            self.position_shift -= num_to_evict  # pyre-ignore [8]
+        return super().get_freqs(input_pos + self.position_shift, seq_len)

     def rerotate_k(
         self,
```

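A minimal usage sketch of the new constructor and `get_freqs`, assuming the import path mirrors the file location above and reusing the configuration from the test below (`window_size=252`, `sink_size=4`, i.e. a 256-slot cache):

```python
# Minimal usage sketch; the import path is assumed from the file location in this diff.
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256)
rope = RopeWithAttentionSink(
    params=params,
    window_size=252,          # window_size + sink_size must equal params.max_seq_len
    sink_size=4,
    eviction_batch_size=1,
)

# 10 tokens starting at position 250 would overflow the 256-slot cache, so
# get_freqs evicts 4 positions internally and reads the tables starting at 246.
freqs_cos, freqs_sin = rope.get_freqs(
    input_pos=torch.tensor([250], dtype=torch.int32),
    seq_len=10,
)
```

Note that `position_shift` is stateful: repeated calls keep accumulating the shift, which is why the test below re-creates the rope object for every parameterized case.
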
examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 49 additions & 2 deletions
```diff
@@ -17,10 +17,57 @@

 class RopeWithAttentionSinkTest(unittest.TestCase):

+    def _init_rope(self, params: ModelArgs, eviction_batch_size: int):
+        return RopeWithAttentionSink(
+            params=params,
+            window_size=252,
+            sink_size=4,
+            eviction_batch_size=eviction_batch_size,
+        )
+
     def setUp(self):
         torch.manual_seed(42)
-        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
-        self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)
+        self.params = ModelArgs(
+            use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256
+        )
+        self.rope_with_attention_sink = self._init_rope(
+            params=self.params, eviction_batch_size=1
+        )
+
+    @parameterized.expand(
+        [
+            [0, 10, 1, 0],  # No shift
+            [250, 10, 1, 246],  # Some shift
+            [256, 10, 1, 246],  # All shift
+            [0, 10, 30, 0],  # No shift with batch eviction
+            [250, 10, 30, 220],  # Some shift with batch eviction
+            [256, 10, 30, 226],  # All shift with batch eviction
+        ]
+    )
+    def test_get_freqs(
+        self, input_pos, seq_len, eviction_batch_size, expected_result_pos
+    ):
+        self.rope_with_attention_sink = self._init_rope(
+            params=self.params, eviction_batch_size=eviction_batch_size
+        )
+
+        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
+            input_pos=torch.tensor([input_pos], dtype=torch.int32),
+            seq_len=seq_len,
+        )
+
+        torch.testing.assert_close(
+            freqs_cos,
+            self.rope_with_attention_sink.freqs_cos.narrow(
+                0, expected_result_pos, seq_len
+            ),
+        )
+        torch.testing.assert_close(
+            freqs_sin,
+            self.rope_with_attention_sink.freqs_sin.narrow(
+                0, expected_result_pos, seq_len
+            ),
+        )

     @parameterized.expand(
         [
```

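The `expected_result_pos` values in the parameterized cases above follow directly from the eviction rule in `get_freqs`. A quick standalone check of that arithmetic (no executorch needed; `max_seq_length = 256` as in the test setup):

```python
# Reproduce expected_result_pos for each parameterized case from the eviction rule.
max_seq_length = 256  # window_size (252) + sink_size (4), as in the test setup

cases = [  # (input_pos, seq_len, eviction_batch_size, expected_result_pos)
    (0, 10, 1, 0),
    (250, 10, 1, 246),
    (256, 10, 1, 246),
    (0, 10, 30, 0),
    (250, 10, 30, 220),
    (256, 10, 30, 226),
]

for input_pos, seq_len, batch_size, expected in cases:
    position_shift = 0  # a fresh rope per case, so the shift starts at zero
    if input_pos + position_shift + seq_len > max_seq_length:
        position_shift -= max(
            input_pos + position_shift - max_seq_length + seq_len, batch_size
        )
    assert input_pos + position_shift == expected
```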