Commit 8f3a83b

swolchok authored and facebook-github-bot committed
Fix rope source transformation error message (#5630)
Summary:
Pull Request resolved: #5630

The error message here accessed attention incorrectly.

Reviewed By: RylanC24, larryliu0820
Differential Revision: D63394466
fbshipit-source-id: 6e3d859ba20f27c26fc44b53630fbcd0eb947561
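To illustrate the bug this commit fixes: the assertion's f-string message referenced `module.attention`, but the module being transformed is the whole transformer, which only exposes attention through `module.layers[i].attention`. Because Python evaluates an assert's message lazily, the broken access only surfaced when the assertion actually failed, replacing the intended message with an AttributeError. A minimal, hypothetical sketch (these stand-in classes are not the real ExecuTorch model classes):

```python
# Hypothetical stand-ins for the real model classes, just to show the failure mode.
class Attention:
    n_kv_heads = 2
    n_local_heads = 4       # deliberately mismatched with n_local_kv_heads
    n_local_kv_heads = 2


class Layer:
    def __init__(self):
        self.attention = Attention()


class Transformer:
    def __init__(self):
        self.layers = [Layer()]


module = Transformer()

# The pre-fix assert: the condition fails, so Python evaluates the f-string
# message, and `module.attention` (which does not exist) raises AttributeError,
# hiding the diagnostic the author intended to show.
caught = None
try:
    assert (
        module.layers[0].attention.n_local_kv_heads
        == module.layers[0].attention.n_local_heads
    ), f"For q got {module.attention.n_kv_heads}"
except AttributeError as e:
    caught = e

print(type(caught).__name__)  # the error message never rendered
```

The fix in this commit hoists `module.layers[0].attention` into a local (`module_attention`) and uses it consistently in both the condition and the message, so a failing assertion now reports the head counts instead of crashing.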
1 parent bdaeede commit 8f3a83b

File tree

1 file changed: +4 additions, −4 deletions
  • examples/models/llama2/source_transformation


examples/models/llama2/source_transformation/rope.py

Lines changed: 4 additions & 4 deletions

@@ -16,11 +16,11 @@ def materialze_broadcast_of_rope_freq_cis(
     assert module.freqs_cos.dim() == 2
     dim0 = module.freqs_cos.size(0)
     dim1 = module.freqs_cos.size(1)
+    module_attention = module.layers[0].attention
     assert (
-        module.layers[0].attention.n_local_kv_heads
-        == module.layers[0].attention.n_local_heads
-    ), f"For rope freqs to be materialzed for broadcast q, k, v num heads must match. For q got {module.attention.n_kv_heads} for k got {module.attention.n_local_heads} and v got {module.attention.n_local_kv_heads}"
-    num_heads = module.layers[0].attention.n_local_heads
+        module_attention.n_local_kv_heads == module_attention.n_local_heads
+    ), f"For rope freqs to be materialized for broadcast, q, k, v num heads must match. For q got {module_attention.n_kv_heads} for k got {module_attention.n_local_heads} and v got {module_attention.n_local_kv_heads}"
+    num_heads = module_attention.n_local_heads
     module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
     module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
     assert module.freqs_sin.dim() == 2
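For context, the `view` / `expand` / `contiguous` sequence the patched function applies materializes the rope frequency tensor across the heads axis ahead of time, instead of relying on runtime broadcasting. A minimal sketch with assumed shapes (`seq_len`, `num_heads`, and the half head dimension are illustrative values, not taken from any real model config):

```python
import torch

# Assumed illustrative shapes: freqs_cos starts as a 2-D tensor
# [seq_len, head_dim // 2], as the function's dim() == 2 assert requires.
seq_len, num_heads, half_head_dim = 8, 4, 16
freqs_cos = torch.randn(seq_len, half_head_dim)

assert freqs_cos.dim() == 2
dim0 = freqs_cos.size(0)
dim1 = freqs_cos.size(1)

# view() inserts a singleton heads axis; expand() repeats it across num_heads
# without copying; contiguous() then forces an actual copy, baking the
# broadcast into memory so no broadcasting happens at inference time.
freqs_cos = freqs_cos.view(dim0, 1, dim1)
freqs_cos = freqs_cos.expand(dim0, num_heads, dim1).contiguous()

print(freqs_cos.shape)  # torch.Size([8, 4, 16])
```

This is also why the assertion guarded by this commit exists at all: materializing one shared `[seq_len, num_heads, dim]` tensor is only valid when q, k, and v use the same number of heads.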
