|
7 | 7 |
|
8 | 8 | # Different RoPE implementations
|
9 | 9 |
|
| 10 | +import math |
10 | 11 | from typing import Tuple
|
11 | 12 |
|
12 | 13 | import torch
|
13 | 14 |
|
14 | 15 | # ======================== Stock Implementation ========================
|
15 | 16 |
|
16 | 17 |
|
17 |
| -def precompute_freqs_cis(dim: int, end: int, theta: float): |
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    """Rescale RoPE base frequencies for long-context llama3 (NTK-by-parts style).

    Frequencies whose wavelength is short relative to the original context
    window are kept unchanged, wavelengths longer than the original window
    are divided by ``scale_factor``, and the band in between interpolates
    linearly between the two regimes.

    Args:
        freqs: 1-D tensor of per-dimension rotary base frequencies.

    Returns:
        Tensor of rescaled frequencies with the same shape, dtype, and
        device as ``freqs``.
    """
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    # Interpolation weight for the mid band; the constants above guarantee
    # high_freq_factor != low_freq_factor, so the denominator is nonzero.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    # Long wavelengths get the full 1/scale_factor shrink; the mid band
    # blends between shrunk and unscaled frequencies.
    mid_or_low = torch.where(
        wavelen > low_freq_wavelen,
        freqs / scale_factor,
        (1 - smooth) * freqs / scale_factor + smooth * freqs,
    )
    # Short wavelengths (high frequencies) are left untouched.
    return torch.where(wavelen < high_freq_wavelen, freqs, mid_or_low)
| 41 | + |
| 42 | + |
def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cos/sin tables used to apply rotary position embeddings.

    Args:
        dim: head dimension; only ``dim // 2`` distinct frequencies are built.
        end: number of positions to precompute (max sequence length).
        theta: RoPE base; larger values give longer wavelengths.
        use_scaled: if True, rescale frequencies via ``apply_scaling`` for
            long-context llama3 models.

    Returns:
        Tuple ``(freqs_cos, freqs_sin)`` of float tensors, each of shape
        ``(end, dim // 2)``.
    """
    # One inverse frequency per pair of dimensions: theta^(-2i/dim).
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)  # pyre-ignore
    if use_scaled:
        freqs = apply_scaling(freqs)  # pyre-ignore
    # Outer product: angle at position t for frequency f is t * f.
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin
|
|
0 commit comments