
Commit 61c722c

Update model arg name to rope_theta to be consistent with the name used on Llama's website
As title
1 parent 2c467dd commit 61c722c

File tree

1 file changed (+2, -2 lines)


examples/models/llama2/llama_transformer.py

Lines changed: 2 additions & 2 deletions

@@ -88,7 +88,7 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
-    rope_freq_base: float = 10000.0  # The base frequency for RoPE
+    rope_theta: float = 10000.0  # The base frequency for RoPE
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -448,7 +448,7 @@ def __init__(self, params: ModelArgs):
                 if params.ffn_dim_multiplier is None
                 else params.max_seq_len * 2  # Sharded checkpoint.
             ),
-            params.rope_freq_base,
+            params.rope_theta,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
