Skip to content

Commit 6c69ebd

Browse files
larryliu0820 authored and facebook-github-bot committed
Support llama3.1 (#4376)
Summary: Pull Request resolved: #4376 Add scaled RoPE Reviewed By: Gasoonjia Differential Revision: D60129927 fbshipit-source-id: b8d2fadcd3e6985740965ad0185b8fb516806c22
1 parent 11b2fcb commit 6c69ebd

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# Please refer to README.md in the same folder for more information.
99

1010
from dataclasses import dataclass
11+
from functools import partial
1112
from typing import Optional, Tuple
1213

1314
import torch
@@ -101,6 +102,7 @@ class ModelArgs:
101102
None # The official name to override self.rope_freq_base.
102103
)
103104
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
105+
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
104106
# Additional Model Metadata needed at runtime
105107
bos_idx: int = 1
106108
eos_idx: int = 3
@@ -453,7 +455,9 @@ def __init__(self, params: ModelArgs):
453455
if params.use_hf_rope:
454456
self.precompute_freqs_cis = hf_precompute_freqs_cis
455457
else:
456-
self.precompute_freqs_cis = precompute_freqs_cis
458+
self.precompute_freqs_cis = partial(
459+
precompute_freqs_cis, use_scaled=params.use_scaled_rope
460+
)
457461
freqs_cos, freqs_sin = self.precompute_freqs_cis(
458462
params.dim // params.n_heads,
459463
(

examples/models/llama2/rope.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,49 @@
77

88
# Different RoPE implementations
99

10+
import math
1011
from typing import Tuple
1112

1213
import torch
1314

1415
# ======================== Stock Implementation ========================
1516

1617

17-
def precompute_freqs_cis(dim: int, end: int, theta: float):
18+
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    """Rescale RoPE base frequencies for long-context llama3.1 models.

    Each frequency is classified by its wavelength (2*pi / freq):
      - shorter than the high-frequency band boundary: kept unchanged,
      - longer than the low-frequency band boundary: divided by ``scale_factor``,
      - in between: linearly interpolated ("smoothed") between the two.

    Returns a tensor with the same shape, dtype, and device as ``freqs``.
    """
    # Values obtained from grid search (llama3.1 reference implementation).
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Vectorized form of the reference per-element loop: identical per-element
    # arithmetic, but one fused pass instead of a Python loop over scalars.
    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    interpolated = (1 - smooth) * freqs / scale_factor + smooth * freqs
    new_freqs = torch.where(
        wavelen < high_freq_wavelen,  # high-frequency band: keep as-is
        freqs,
        torch.where(
            wavelen > low_freq_wavelen,  # low-frequency band: fully scaled
            freqs / scale_factor,
            interpolated,  # transition band: blend of the two regimes
        ),
    )
    return new_freqs.to(dtype=freqs.dtype, device=freqs.device)
41+
42+
43+
def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    """Precompute the RoPE cos/sin tables.

    Builds ``dim // 2`` inverse-power base frequencies from ``theta``,
    optionally rescales them for llama3.1 long-context models, and takes the
    outer product with positions ``0..end-1``.

    Returns a ``(freqs_cos, freqs_sin)`` pair, each of shape
    ``(end, dim // 2)``.
    """
    half_dim = dim // 2
    exponents = torch.arange(0, dim, 2, device="cpu")[:half_dim].float() / dim
    base_freqs = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=base_freqs.device)  # pyre-ignore
    if use_scaled:
        # llama3.1-style wavelength-dependent rescaling.
        base_freqs = apply_scaling(base_freqs)  # pyre-ignore
    angles = torch.outer(positions, base_freqs).float()
    return torch.cos(angles), torch.sin(angles)

0 commit comments

Comments
 (0)