Commit 5e3d46e

Update llama_transformer.py
1 parent 61c722c · commit 5e3d46e

File tree

1 file changed: +7 -2 lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 7 additions & 2 deletions
@@ -88,7 +88,8 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
-    rope_theta: float = 10000.0  # The base frequency for RoPE
+    rope_theta: float = None  # The official name to override self.rope_freq_base.
+    rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -99,6 +100,10 @@ def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
 
+        # rope_theta overrides rope_freq_base since it's the official name.
+        if self.rope_theta is not None:
+            self.rope_freq_base = self.rope_theta
+
         if self.use_sdpa_with_kv_cache_op:
             assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
 
@@ -448,7 +453,7 @@ def __init__(self, params: ModelArgs):
                 if params.ffn_dim_multiplier is None
                 else params.max_seq_len * 2  # Sharded checkpoint.
             ),
-            params.rope_theta,
+            params.rope_freq_base,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)

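For context on where that base frequency ends up: the value is passed into the RoPE table precomputation whose outputs are registered as the `freqs_cos` / `freqs_sin` buffers above. Below is a hedged sketch following the standard Llama recipe; the helper name `precompute_freqs` and its exact signature are assumptions here, not necessarily what this file's real helper looks like:

```python
import torch

def precompute_freqs(dim: int, end: int, theta: float = 10000.0):
    # Inverse frequencies theta ** (-2i / dim) for each rotation pair i.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end).float()   # token positions 0 .. end - 1
    angles = torch.outer(t, freqs)  # (end, dim // 2) angle table
    # Real-valued tables, analogous to the freqs_cos / freqs_sin buffers.
    return torch.cos(angles), torch.sin(angles)

# Head dim 64, 2048 positions, the default base from the diff:
freqs_cos, freqs_sin = precompute_freqs(64, 2048, theta=10000.0)
assert freqs_cos.shape == (2048, 32)
```

Raising the base (the value carried by `rope_theta` / `rope_freq_base`) slows the rotation frequencies, which is the usual lever for long-context Llama variants.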
0 commit comments