Skip to content

Commit ef31608

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Add sdpa arg comments (#5323)
Summary: Pull Request resolved: #5323 Reviewed By: JacobSzwejbka Differential Revision: D62623249 Pulled By: dvorjackz fbshipit-source-id: 468abd913a4dcb9b2474ec34881cfcec2654a024
1 parent 08f16d0 commit ef31608

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,9 @@ def __init__(
218218
def forward(
219219
self,
220220
input_pos: torch.Tensor,
221-
q: torch.Tensor,
222-
k: torch.Tensor,
223-
v: torch.Tensor,
221+
q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
222+
k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
223+
v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim)
224224
bsz,
225225
seqlen,
226226
mask: torch.Tensor,

0 commit comments

Comments
 (0)