Commit ac33011

sxu authored and facebook-github-bot committed
Support HuggingFace RoPE in static attention
Summary: Create a separate Rope forward to allow flexibility in implementation for static attention.

Differential Revision: D69857290
1 parent e1aabb6 commit ac33011

File tree

2 files changed: +55 −10 lines


examples/models/llama/static_attention.py

Lines changed: 26 additions & 10 deletions
@@ -114,15 +114,30 @@ def update(
         return all_data, (out_k_cache, out_v_cache)
 
 
-def _apply_rotary_embedding(
-    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
-) -> torch.Tensor:
-    x_r, x_i = x[..., ::2], x[..., 1::2]
-    x_out_r = x_r * freqs_cos - x_i * freqs_sin
-    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+class _Rope(nn.Module):
+    def __init__(self, use_hf_rope):
+        super().__init__()
+        self.use_hf_rope = use_hf_rope
+
+    def forward(
+        self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+    ) -> torch.Tensor:
+        if self.use_hf_rope:
+            if len(freqs_cos.shape) == 2:
+                freqs_cos = freqs_cos.unsqueeze(0)
+            if len(freqs_sin.shape) == 2:
+                freqs_sin = freqs_sin.unsqueeze(0)
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            x_rotated = torch.cat((-x2, x1), dim=-1)
+            return x * freqs_cos + x_rotated * freqs_sin
+        else:
+            x_r, x_i = x[..., ::2], x[..., 1::2]
+            x_out_r = x_r * freqs_cos - x_i * freqs_sin
+            x_out_i = x_r * freqs_sin + x_i * freqs_cos
 
-    x_out = torch.cat([x_out_r, x_out_i], dim=-1)
-    return x_out
+            x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+            return x_out
 
 
 @register_attention("static")
@@ -172,6 +187,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
             [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)]
         )
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
+        self.rope = _Rope(rope.params.use_hf_rope)
 
     def forward(
         self,
@@ -191,8 +207,8 @@ def forward(
         new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
         new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
-        new_qs = [_apply_rotary_embedding(q, freqs_cos, freqs_sin) for q in new_qs]
-        new_ks = [_apply_rotary_embedding(k, freqs_cos, freqs_sin) for k in new_ks]
+        new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
+        new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
 
         all_ks = []
         all_vs = []
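
For context on the two code paths in _Rope.forward: the HuggingFace convention rotates the first half of the head dimension against the second half (with the cos/sin tables repeated to span the full head dimension), while the original Llama convention treats even/odd indices as real/imaginary pairs. The following standalone sketch (plain PyTorch; names and sizes are illustrative and not part of this diff) reproduces the two rotations side by side:

import torch

def rope_hf(x, cos, sin):
    # HuggingFace style: pair dimension i with dimension i + head_dim // 2.
    # cos/sin span the full head dimension (each frequency repeated twice).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin

def rope_interleaved(x, cos, sin):
    # Original Llama style: even/odd indices form (real, imag) pairs.
    # cos/sin span half the head dimension.
    x_r, x_i = x[..., ::2], x[..., 1::2]
    out_r = x_r * cos - x_i * sin
    out_i = x_r * sin + x_i * cos
    return torch.cat([out_r, out_i], dim=-1)

# Toy frequencies for a single head (head_dim=8, seq_len=4).
head_dim, seq_len = 8, 4
theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), theta)  # (seq_len, head_dim // 2)
cos_half, sin_half = freqs.cos(), freqs.sin()
cos_full = torch.cat([cos_half, cos_half], dim=-1)         # (seq_len, head_dim)
sin_full = torch.cat([sin_half, sin_half], dim=-1)

x = torch.randn(seq_len, head_dim)
print(rope_hf(x, cos_full, sin_full).shape)           # torch.Size([4, 8])
print(rope_interleaved(x, cos_half, sin_half).shape)  # torch.Size([4, 8])

Both apply the same 2-D rotation per frequency; they differ only in which coordinates are paired, so the convention has to match how the model's weights and caches were laid out, which is why the flag is threaded through from rope.params.use_hf_rope.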

examples/models/llama/tests/test_static_attention.py

Lines changed: 29 additions & 0 deletions
@@ -43,6 +43,35 @@ def test_without_cache(self):
         )
         self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
 
+    def test_hf_rope_without_cache(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=8,
+            use_hf_rope=True,
+        )
+        layer_id = 0
+        rope = Rope(config)
+        attn_mha = AttentionMHA(config, layer_id, rope).eval()
+        static_attn = StaticAttention(config, layer_id, rope).eval()
+        static_attn.load_weights_from_attention_mha(attn_mha)
+
+        x = torch.rand(1, config.max_seq_len, config.dim)
+        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
+        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+        mask = torch.triu(
+            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
+            diagonal=1,
+        )
+        y, _ = static_attn(
+            x,
+            freqs_cos.unsqueeze(0),
+            freqs_sin.unsqueeze(0),
+            mask=mask,
+        )
+        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+
     def test_with_cache(self):
         config = ModelArgs(
             dim=64,
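
A note on the new test: it builds the reference AttentionMHA and the StaticAttention module with use_hf_rope=True, feeds both the same input, and checks that the outputs agree under a causal mask; the freqs are given a leading batch dimension via unsqueeze(0) before being passed to static_attn. The mask itself is the usual upper-triangular -inf pattern; a minimal reproduction of just that construction (sizes are illustrative):

import torch

seq_len = 8
# -inf strictly above the diagonal blocks attention to future positions;
# the leading 1 is a broadcastable batch/head dimension.
mask = torch.triu(
    torch.full((1, seq_len, seq_len), float("-inf")),
    diagonal=1,
)
print(mask[0, :3, :3])
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])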
