
Commit 027ad54

larryliu0820 authored and facebook-github-bot committed

Fix a rotary position encoding bug in kv cache

Summary: The `Transformer` module has two branches depending on whether the kv cache is used. For the branch that uses the kv cache, the rotary position encoding should be obtained by slicing the precomputed values with `start_pos : start_pos + seqlen`. This diff fixes that.

Reviewed By: JacobSzwejbka

Differential Revision: D53954747

fbshipit-source-id: d79ea06e97d5a5f06533e4e4db11f61e2a0fae87
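For context, a minimal, self-contained sketch of why the slice offset matters when decoding with a kv cache. The helper `precompute_rotary` and the sizes below are illustrative assumptions, not the ExecuTorch implementation; the point is that with `[:seqlen]` every single-token decode step is rotated as if it sat at position 0, while `[start_pos : start_pos + seqlen]` uses the token's true absolute position.

    import torch

    def precompute_rotary(head_dim: int, max_seq_len: int, theta: float = 10000.0):
        # Standard RoPE-style table: one row of (cos, sin) angles per absolute position.
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, freqs)
        return torch.cos(angles), torch.sin(angles)

    freqs_cos, freqs_sin = precompute_rotary(head_dim=64, max_seq_len=128)

    seqlen, start_pos = 1, 17  # decoding one token at absolute position 17

    buggy_cos = freqs_cos[:seqlen]                         # pre-fix: always position 0
    fixed_cos = freqs_cos[start_pos : start_pos + seqlen]  # post-fix: position 17

    # Different angles feed the attention rotation, so the pre-fix slice produces
    # wrong attention scores for every decoded token after the first.
    assert not torch.allclose(buggy_cos, fixed_cos)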
1 parent f4c4ad3 commit 027ad54

File tree

1 file changed (+15, -14 lines)


examples/models/llama2/model.py

Lines changed: 15 additions & 14 deletions
@@ -405,6 +405,9 @@ def forward(
     ) -> Union[
         torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]
     ]:
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+
         if self.use_kv_cache:
             assert (
                 cache_k is not None and cache_v is not None and start_pos is not None
@@ -415,29 +418,27 @@ def forward(
             assert (
                 cache_v.size(0) == self.n_layers
             ), f"{cache_v.size(0)} != {self.n_layers}"
-        else:
-            assert (
-                start_pos is None and cache_k is None and cache_v is None,
-                "Caches and start_pos are unused when use_kv_cache is False",
-            )
-
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
-        freqs_cos = self.freqs_cos[:seqlen]
-        freqs_sin = self.freqs_sin[:seqlen]
 
-        if self.use_kv_cache:
-            sp = start_pos.item()  # pyre-ignore[16]
+            sp = start_pos.item()
             # self.params.max_seq_len - 1 because of 0 based indexing, and - 1 again because our input seq len is 1 and its added to the cache before accessing the cache
             torch._constrain_as_size(sp, min=0, max=self.params.max_seq_len - 2)
             torch._constrain_as_value(
-                cache_k.shape[0],  # pyre-ignore[16]
-                min=self.n_layers,
+                cache_k.shape[0],
                 max=self.n_layers,
+                min=self.n_layers,
             )
             torch._constrain_as_value(
                 cache_v.shape[0], min=self.n_layers, max=self.n_layers
             )
+            # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+            freqs_cos = self.freqs_cos[sp : sp + seqlen]
+            freqs_sin = self.freqs_sin[sp : sp + seqlen]
+        else:
+            assert (
+                start_pos is None and cache_k is None and cache_v is None
+            ), "Caches and start_pos are unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seqlen]
+            freqs_sin = self.freqs_sin[:seqlen]
 
         for index, layer in enumerate(self.layers):
             if self.use_kv_cache:
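As a usage sketch only: assuming `forward(tokens, start_pos, cache_k, cache_v)` returns the logits together with updated per-layer caches (one plausible reading of the return annotation above, not a confirmed API), a greedy decode loop would pass an explicit `start_pos` each step so the rotary slice tracks the absolute token position.

    import torch

    def greedy_decode(model, bos_token: int, cache_k: torch.Tensor,
                      cache_v: torch.Tensor, steps: int) -> torch.Tensor:
        # Hypothetical loop for illustration; the real ExecuTorch runner may wire the
        # caches differently. `model` is assumed to be a Transformer built with
        # use_kv_cache=True, and cache_k/cache_v to be (n_layers, ...) tensors.
        tokens = torch.tensor([[bos_token]])  # shape (bsz=1, seqlen=1)
        out = [tokens]
        for pos in range(steps):
            # Each step feeds one token at absolute position `pos`; with the fix above,
            # the model rotates it with freqs_cos/freqs_sin[pos : pos + 1].
            start_pos = torch.tensor([pos])
            logits, new_k, new_v = model(tokens, start_pos, cache_k, cache_v)
            # The return annotation suggests per-layer cache lists; stacking them back
            # into single tensors is an assumption about the cache layout.
            cache_k, cache_v = torch.stack(list(new_k)), torch.stack(list(new_v))
            tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy next token
            out.append(tokens)
        return torch.cat(out, dim=1)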
