@@ -88,7 +88,8 @@ class ModelArgs:
88
88
use_sdpa_with_kv_cache_op : bool = (
89
89
False # Use custom sdpa op that updates kv cache in-place
90
90
)
91
- rope_theta : float = 10000.0 # The base frequency for RoPE
91
+ rope_theta : float = None # The official name to override self.rope_freq_base.
92
+ rope_freq_base : float = 10000.0 # The base frequency for RoPE. Keep it for BC.
92
93
# Additional Model Metadata needed at runtime
93
94
bos_idx : int = 1
94
95
eos_idx : int = 3
@@ -99,6 +100,10 @@ def __post_init__(self):
99
100
if self .n_kv_heads is None :
100
101
self .n_kv_heads = self .n_heads
101
102
103
+ # rope_theta overrides rope_freq_base since it's the official name.
104
+ if self .rope_theta is not None :
105
+ self .rope_freq_base = self .rope_theta
106
+
102
107
if self .use_sdpa_with_kv_cache_op :
103
108
assert self .use_kv_cache , "use_sdpa_with_kv_cache_op requires use_kv_cache"
104
109
@@ -448,7 +453,7 @@ def __init__(self, params: ModelArgs):
448
453
if params .ffn_dim_multiplier is None
449
454
else params .max_seq_len * 2 # Sharded checkpoint.
450
455
),
451
- params .rope_theta ,
456
+ params .rope_freq_base ,
452
457
)
453
458
self .register_buffer ("freqs_cos" , freqs_cos , persistent = False )
454
459
self .register_buffer ("freqs_sin" , freqs_sin , persistent = False )
0 commit comments