Skip to content

Add high_freq_factor to ModelArgs #10348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ModelArgs:
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
rope_scale_factor: int = 8
high_freq_factor: int = 4
# Additional Model Metadata needed at runtime
bos_idx: int = 1
eos_idx: int = 3
Expand Down
7 changes: 4 additions & 3 deletions examples/models/llama/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
# ======================== Stock Implementation ========================


def apply_scaling(freqs: torch.Tensor, scale_factor: int):
def apply_scaling(freqs: torch.Tensor, scale_factor: int, high_freq_factor: int):
# Values obtained from grid search
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length

low_freq_wavelen = old_context_len / low_freq_factor
Expand All @@ -47,14 +46,15 @@ def precompute_freqs_cis(
theta: float = 10000.0,
use_scaled: bool = False,
scale_factor: Optional[int] = None,
high_freq_factor: int = 4,
):
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
)
t = torch.arange(end, device=freqs.device) # pyre-ignore
if use_scaled:
assert scale_factor is not None
freqs = apply_scaling(freqs, scale_factor) # pyre-ignore
freqs = apply_scaling(freqs, scale_factor, high_freq_factor) # pyre-ignore
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
Expand Down Expand Up @@ -242,6 +242,7 @@ def __init__(self, params: ModelArgs):
precompute_freqs_cis,
use_scaled=self.params.use_scaled_rope,
scale_factor=self.params.rope_scale_factor,
high_freq_factor=self.params.high_freq_factor,
)
self.apply_rotary_emb = RotaryEmbedding()

Expand Down
Loading