Skip to content

Commit 5a2c5d1

Browse files
committed
Support llama3.1
Summary: Add scaled RoPE. Test Plan: Tested the official checkpoint; it gives meaningful results. Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 1e29544 Pull Request resolved: #4376
1 parent dbc73a6 commit 5a2c5d1

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class ModelArgs:
101101
None # The official name to override self.rope_freq_base.
102102
)
103103
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
104+
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
104105
# Additional Model Metadata needed at runtime
105106
bos_idx: int = 1
106107
eos_idx: int = 3
@@ -462,6 +463,7 @@ def __init__(self, params: ModelArgs):
462463
else params.max_seq_len * 2 # Sharded checkpoint.
463464
),
464465
params.rope_freq_base,
466+
params.use_scaled_rope,
465467
)
466468
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
467469
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

examples/models/llama2/rope.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,46 @@
88
# Different RoPE implementations
99

1010
from typing import Tuple
11-
11+
import math
1212
import torch
1313

1414
# ======================== Stock Implementation ========================
1515

1616

17-
def precompute_freqs_cis(dim: int, end: int, theta: float):
17+
def apply_scaling(
    freqs: torch.Tensor,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
):
    """Apply llama3.1's context-extension scaling to RoPE base frequencies.

    Each frequency is handled according to its wavelength:
    - wavelength shorter than the high-frequency cutoff: kept unchanged;
    - wavelength longer than the low-frequency cutoff: divided by
      ``scale_factor``;
    - in between: linearly interpolated between the two regimes.

    Args:
        freqs: 1-D tensor of RoPE base frequencies.
        scale_factor: divisor applied to low-frequency components.
            Default values were obtained from grid search.
        low_freq_factor: defines the low-frequency wavelength cutoff
            as ``old_context_len / low_freq_factor``.
        high_freq_factor: defines the high-frequency wavelength cutoff
            as ``old_context_len / high_freq_factor``.
        old_context_len: original llama3 context length.

    Returns:
        Tensor of scaled frequencies with the same dtype and device as
        ``freqs``.
    """
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency component: leave untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency component: uniformly down-scale.
            new_freqs.append(freq / scale_factor)
        else:
            # Mid band: smoothly interpolate between scaled and unscaled.
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
40+
41+
42+
def precompute_freqs_cis(
43+
dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
44+
):
1845
freqs = 1.0 / (
1946
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
2047
)
2148
t = torch.arange(end, device=freqs.device) # pyre-ignore
49+
if use_scaled:
50+
freqs = apply_scaling(freqs)
2251
freqs = torch.outer(t, freqs).float() # pyre-ignore
2352
freqs_cos = torch.cos(freqs)
2453
freqs_sin = torch.sin(freqs)

0 commit comments

Comments
 (0)