Commit cb7b345

pytorchbot and helunwencser authored and committed
implement position encoding for shifted tokens
Pull Request resolved: #6646

AttentionSink uses tokens' positions in the KVCache instead of their positions in the actual text. When tokens get shifted in the KVCache, their q and k position embeddings need to be updated. The original [implementation](https://github.com/mit-han-lab/streaming-llm) of AttentionSink with RoPE caches the unrotated q and k in the KVCache and applies the position embedding during inference.

This PR adds `RopeWithAttentionSink` instead. It assumes that q and k are already encoded with their original positions; when tokens are shifted, we reapply only the position delta. This has two benefits:

- it minimizes our code, since the existing `llama_transformer` already applies the RoPE embedding before doing the KVCache update
- it avoids a performance regression when tokens are not shifted, because we don't need to reapply position encoding in the KVCache for them

ghstack-source-id: 255579838
Differential Revision: [D65366440](https://our.internmc.facebook.com/intern/diff/D65366440/)

---------

Co-authored-by: Lunwen He <[email protected]>
1 parent 1ba7f51 commit cb7b345
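To make the "reapply only the position delta" idea concrete, here is a minimal standalone sketch. The `rotate` helper below is hypothetical and only mirrors the 2-D rotation RoPE applies per frequency; it is not part of this diff.

```python
import torch

def rotate(pair: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    """Rotate a (real, imag) pair by `angle` radians."""
    cos, sin = torch.cos(angle), torch.sin(angle)
    r, i = pair.unbind(-1)
    return torch.stack([r * cos - i * sin, r * sin + i * cos], dim=-1)

theta = torch.tensor(0.01)   # one RoPE frequency
k = torch.rand(2)            # one unrotated (real, imag) key pair
old_pos, new_pos = 8, 7      # the token shifts left by one cache slot

# The cached key was already rotated for its original position.
k_cached = rotate(k, old_pos * theta)

# Rerotate the cached key by the position delta only ...
k_rerotated = rotate(k_cached, (new_pos - old_pos) * theta)

# ... which matches rotating the raw key directly to the new position.
torch.testing.assert_close(k_rerotated, rotate(k, new_pos * theta))
print("delta rerotation matches direct rotation")
```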

File tree (4 files changed, +190 −0 lines)

examples/models/llama/TARGETS
examples/models/llama/rope.py
examples/models/llama/source_transformation/attention_sink.py
examples/models/llama/source_transformation/test_attention_sink.py

examples/models/llama/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -93,6 +93,7 @@ runtime.python_library(
         "source_transformation/sdpa.py",
         "source_transformation/spin_quant.py",
         "source_transformation/vulkan_rope.py",
+        "source_transformation/attention_sink.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",
@@ -213,3 +214,16 @@ runtime.python_test(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+runtime.python_test(
+    name = "attention_sink_test",
+    srcs = [
+        "source_transformation/test_attention_sink.py",
+    ],
+    supports_static_listing = False,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        "//caffe2:torch",
+        ":export_library",
+    ],
+)

examples/models/llama/rope.py

Lines changed: 41 additions & 0 deletions
@@ -92,6 +92,22 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)


+def apply_rotary_emb_to_k(
+    xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+    freqs_cos = reshape_for_broadcast(freqs_cos, xk_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xk_r)
+
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+    return xk_out.type_as(xk)
+
+
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -160,3 +176,28 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
+
+
+def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the key tensor.
+
+    Args:
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze
+            cos[position_ids] and sin[position_ids] so that they can be properly broadcast
+            to the dimensions of k. For example, cos[position_ids] and sin[position_ids]
+            have the shape [batch_size, seq_len, head_dim]. Then, if k has the shape
+            [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shape of k.
+            Similarly, if k has the shape [batch_size, seq_len, heads, head_dim], set
+            unsqueeze_dim=2.
+    Returns:
+        `torch.Tensor`: the key tensor rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return k_embed
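For readers unfamiliar with the `unsqueeze_dim` broadcasting described in the docstring above, a tiny standalone check illustrates it; the shapes below are illustrative and not taken from this diff.

```python
import torch

# Hypothetical shapes: batch=1, heads=2, seq_len=4, head_dim=6.
cos = torch.rand(1, 4, 6)   # [batch_size, seq_len, head_dim]
sin = torch.rand(1, 4, 6)
k = torch.rand(1, 2, 4, 6)  # [batch_size, heads, seq_len, head_dim]

# unsqueeze_dim=1 inserts a heads axis: cos/sin become [batch_size, 1, seq_len, head_dim]
# and broadcast against k across the heads dimension.
assert (k * cos.unsqueeze(1)).shape == k.shape
assert (k * sin.unsqueeze(1)).shape == k.shape
```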
examples/models/llama/source_transformation/attention_sink.py (new file)

Lines changed: 62 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
from executorch.examples.models.llama.rope import (
    apply_rotary_emb_to_k,
    hf_apply_rotary_emb_to_k,
)


class RopeWithAttentionSink(Rope):
    """
    Rope that helps adjust position encoding when tokens are shifted in KVCache.
    For AttentionSink, when tokens are shifted in KVCache, we need to use their
    positions in KVCache instead of their positions in the actual text.
    """

    def __init__(self, params: ModelArgs):
        super().__init__(params)
        if self.params.use_hf_rope:
            self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
        else:
            self.apply_rotary_emb_to_k = apply_rotary_emb_to_k

    def rerotate_k(
        self,
        k: torch.Tensor,
        original_position: int,
        new_position: int,
    ):
        """
        Rerotate k from original_position to new_position. This is done by rotating k
        by delta = new_position * theta - original_position * theta with the rotation
        matrix:
            (cos(delta), -sin(delta)
             sin(delta),  cos(delta))

        The shape of k is (batch_size, seq_len, n_local_heads, head_dim).

        Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
        """
        seq_len = k.shape[1]
        original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len)
        original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len)
        new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len)
        new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len)
        rerotation_cos = (
            new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
        )
        rerotation_sin = (
            new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
        )

        return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)
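For reference, the `rerotation_cos` and `rerotation_sin` terms above are just the angle-difference identities applied elementwise to the cached cos/sin tables. Writing $\theta_i$ for the RoPE frequency of dimension pair $i$ and $p_{\text{orig}}, p_{\text{new}}$ for the positions (these symbols are for illustration, not names from the code):

$$
\begin{aligned}
\Delta_i &= (p_{\text{new}} - p_{\text{orig}})\,\theta_i,\\
\cos\Delta_i &= \cos(p_{\text{new}}\theta_i)\cos(p_{\text{orig}}\theta_i) + \sin(p_{\text{new}}\theta_i)\sin(p_{\text{orig}}\theta_i),\\
\sin\Delta_i &= \sin(p_{\text{new}}\theta_i)\cos(p_{\text{orig}}\theta_i) - \cos(p_{\text{new}}\theta_i)\sin(p_{\text{orig}}\theta_i).
\end{aligned}
$$

Applying this delta rotation to a key already rotated to $p_{\text{orig}}$ therefore equals rotating the raw key directly to $p_{\text{new}}$, which is exactly what the test below verifies.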
examples/models/llama/source_transformation/test_attention_sink.py (new file)

Lines changed: 73 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs

from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)
from parameterized import parameterized


class RopeWithAttentionSinkTest(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(42)
        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
        self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)

    @parameterized.expand(
        [
            [128, 127],  # Rotate left
            [128, 128],  # No rotation
            [128, 129],  # Rotate right
        ]
    )
    def test_rotate(self, original_position, new_position):
        seq_len = 32

        q = torch.rand(
            1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32
        )
        k = torch.rand(
            1,
            seq_len,
            self.params.n_heads,
            self.params.head_dim,
            dtype=torch.float32,
        )
        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
            input_pos=torch.tensor([original_position], dtype=torch.int32),
            seq_len=seq_len,
        )
        _, pre_rotated_k = self.rope_with_attention_sink.forward(
            q=q,
            k=k,
            freqs_cos=freqs_cos,
            freqs_sin=freqs_sin,
        )

        rerotated_k = self.rope_with_attention_sink.rerotate_k(
            k=pre_rotated_k,
            original_position=original_position,
            new_position=new_position,
        )

        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
            input_pos=torch.tensor([new_position], dtype=torch.int32),
            seq_len=seq_len,
        )
        _, expected_k = self.rope_with_attention_sink.forward(
            q=q,
            k=k,
            freqs_cos=freqs_cos,
            freqs_sin=freqs_sin,
        )

        torch.testing.assert_close(rerotated_k, expected_k)
