17
17
# ======================== Stock Implementation ========================
18
18
19
19
20
- def apply_scaling (freqs : torch .Tensor , scale_factor : int ):
20
+ def apply_scaling (freqs : torch .Tensor , scale_factor : int , high_freq_factor : int ):
21
21
# Values obtained from grid search
22
22
low_freq_factor = 1
23
- high_freq_factor = 4
24
23
old_context_len = 8192 # original llama3 length
25
24
26
25
low_freq_wavelen = old_context_len / low_freq_factor
@@ -47,14 +46,15 @@ def precompute_freqs_cis(
47
46
theta : float = 10000.0 ,
48
47
use_scaled : bool = False ,
49
48
scale_factor : Optional [int ] = None ,
49
+ high_freq_factor : int = 4 ,
50
50
):
51
51
freqs = 1.0 / (
52
52
theta ** (torch .arange (0 , dim , 2 , device = "cpu" )[: (dim // 2 )].float () / dim )
53
53
)
54
54
t = torch .arange (end , device = freqs .device ) # pyre-ignore
55
55
if use_scaled :
56
56
assert scale_factor is not None
57
- freqs = apply_scaling (freqs , scale_factor ) # pyre-ignore
57
+ freqs = apply_scaling (freqs , scale_factor , high_freq_factor ) # pyre-ignore
58
58
freqs = torch .outer (t , freqs ).float ()
59
59
freqs_cos = torch .cos (freqs )
60
60
freqs_sin = torch .sin (freqs )
@@ -242,6 +242,7 @@ def __init__(self, params: ModelArgs):
242
242
precompute_freqs_cis ,
243
243
use_scaled = self .params .use_scaled_rope ,
244
244
scale_factor = self .params .rope_scale_factor ,
245
+ high_freq_factor = self .params .high_freq_factor ,
245
246
)
246
247
self .apply_rotary_emb = RotaryEmbedding ()
247
248
0 commit comments