Skip to content

Commit 61200ef

Browse files
committed
llama : fix edge case finding batch seq_id of split recurrent cell
This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash.
1 parent 18d1c14 commit 61200ef

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

llama.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3879,11 +3879,17 @@ static bool llama_cache_find_slot(
38793879
if (cell.tail_rc == 0) {
38803880
cache.rs.clear_cell(cell);
38813881
} else {
3882-
// TODO: does this always work correctly
3883-
// even if there are more than one seq_node in this cell?
3882+
// Find the seq_id of the first tail of this cell
3883+
llama_seq_id seq_id = -1;
3884+
for (llama_rs_seq_node & seq_node : cell.seq_nodes) {
3885+
if (seq_node.is_tail()) {
3886+
seq_id = seq_node.seq_id;
3887+
break;
3888+
}
3889+
}
3890+
GGML_ASSERT(seq_id != -1);
38843891

38853892
// Which seq_id of the batch is it?
3886-
llama_seq_id seq_id = cell.seq_nodes[0].seq_id;
38873893
int32_t nth_seq_id = -1;
38883894
for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) {
38893895
if (seq_id == batch.seq_id[s][0]) {

0 commit comments

Comments
 (0)