Skip to content

Commit 3eac583

Browse files
authored
Add high_freq_factor to ModelArgs
Differential Revision: D73418899. Pull Request resolved: #10348.
1 parent 095722b commit 3eac583

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

examples/models/llama/model_args.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ class ModelArgs:
4646
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
4747
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
4848
rope_scale_factor: int = 8
49+
high_freq_factor: int = 4
4950
# Additional Model Metadata needed at runtime
5051
bos_idx: int = 1
5152
eos_idx: int = 3

examples/models/llama/rope.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,9 @@
1717
# ======================== Stock Implementation ========================
1818

1919

20-
def apply_scaling(freqs: torch.Tensor, scale_factor: int):
20+
def apply_scaling(freqs: torch.Tensor, scale_factor: int, high_freq_factor: int):
2121
# Values obtained from grid search
2222
low_freq_factor = 1
23-
high_freq_factor = 4
2423
old_context_len = 8192 # original llama3 length
2524

2625
low_freq_wavelen = old_context_len / low_freq_factor
@@ -47,14 +46,15 @@ def precompute_freqs_cis(
4746
theta: float = 10000.0,
4847
use_scaled: bool = False,
4948
scale_factor: Optional[int] = None,
49+
high_freq_factor: int = 4,
5050
):
5151
freqs = 1.0 / (
5252
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
5353
)
5454
t = torch.arange(end, device=freqs.device) # pyre-ignore
5555
if use_scaled:
5656
assert scale_factor is not None
57-
freqs = apply_scaling(freqs, scale_factor) # pyre-ignore
57+
freqs = apply_scaling(freqs, scale_factor, high_freq_factor) # pyre-ignore
5858
freqs = torch.outer(t, freqs).float()
5959
freqs_cos = torch.cos(freqs)
6060
freqs_sin = torch.sin(freqs)
@@ -242,6 +242,7 @@ def __init__(self, params: ModelArgs):
242242
precompute_freqs_cis,
243243
use_scaled=self.params.use_scaled_rope,
244244
scale_factor=self.params.rope_scale_factor,
245+
high_freq_factor=self.params.high_freq_factor,
245246
)
246247
self.apply_rotary_emb = RotaryEmbedding()
247248

0 commit comments

Comments
 (0)