Skip to content

Support HuggingFace RoPE in static attention #8569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions examples/models/llama/static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,30 @@ def update(
return all_data, (out_k_cache, out_v_cache)


def _apply_rotary_embedding(
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
x_r, x_i = x[..., ::2], x[..., 1::2]
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos
class _Rope(nn.Module):
def __init__(self, use_hf_rope):
super().__init__()
self.use_hf_rope = use_hf_rope

def forward(
self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
if self.use_hf_rope:
if len(freqs_cos.shape) == 2:
freqs_cos = freqs_cos.unsqueeze(0)
if len(freqs_sin.shape) == 2:
freqs_sin = freqs_sin.unsqueeze(0)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x_rotated = torch.cat((-x2, x1), dim=-1)
return x * freqs_cos + x_rotated * freqs_sin
else:
x_r, x_i = x[..., ::2], x[..., 1::2]
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos

x_out = torch.cat([x_out_r, x_out_i], dim=-1)
return x_out
x_out = torch.cat([x_out_r, x_out_i], dim=-1)
return x_out


@register_attention("static")
Expand Down Expand Up @@ -172,6 +187,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
[StaticVCache(layer_id, i) for i in range(self.n_kv_heads)]
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
self.rope = _Rope(rope.params.use_hf_rope)

def forward(
self,
Expand All @@ -191,8 +207,8 @@ def forward(
new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
new_qs = [_apply_rotary_embedding(q, freqs_cos, freqs_sin) for q in new_qs]
new_ks = [_apply_rotary_embedding(k, freqs_cos, freqs_sin) for k in new_ks]
new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]

all_ks = []
all_vs = []
Expand All @@ -211,7 +227,7 @@ def forward(
kv_idx = i // self.n_heads_per_kv_group
attn = new_qs[i] @ all_ks[kv_idx].transpose(-2, -1)
attn = attn * self.inv_scale
attn = attn + mask # pyre-ignore
attn = attn + mask
attn = F.softmax(attn, dim=-1)
heads.append(attn @ all_vs[kv_idx])

Expand Down
29 changes: 29 additions & 0 deletions examples/models/llama/tests/test_static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,35 @@ def test_without_cache(self):
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

def test_hf_rope_without_cache(self):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
use_hf_rope=True,
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)
mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
)
y, _ = static_attn(
x,
freqs_cos.unsqueeze(0),
freqs_sin.unsqueeze(0),
mask=mask,
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

def test_with_cache(self):
config = ModelArgs(
dim=64,
Expand Down
Loading