Skip to content

Commit 332f073

Browse files
committed
cont : support non-continuous slots
ggml-ci
1 parent: 39d0b1e · commit: 332f073

File tree

2 files changed

+46
-29
lines changed

2 files changed

+46
-29
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
400400
bool success = true;
401401

402402
for (const auto & ubatch : ubatches) {
403+
// non-continuous slots require support for ggml_set_rows()
404+
const bool cont = supports_set_rows ? false : true;
405+
403406
// only find a suitable slot for the ubatch. don't modify the cells yet
404-
const auto sinfo_new = find_slot(ubatch);
407+
const auto sinfo_new = find_slot(ubatch, cont);
405408
if (sinfo_new.empty()) {
406409
success = false;
407410
break;
@@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
521524
return updated;
522525
}
523526

524-
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
527+
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
525528
const uint32_t n_tokens = ubatch.n_tokens;
526529

527530
uint32_t head_cur = this->head;
@@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
595598
}
596599
}
597600

601+
uint32_t n_found = 0;
598602
uint32_t n_tested = 0;
599603

604+
const uint32_t n_test = cont ? n_tokens : 1;
605+
606+
slot_info res;
607+
608+
res.idxs.resize(n_tokens);
609+
600610
while (true) {
601-
if (head_cur + n_tokens > cells.size()) {
611+
if (head_cur + n_test > cells.size()) {
602612
n_tested += cells.size() - head_cur;
603613
head_cur = 0;
604614
continue;
605615
}
606616

607-
bool found = true;
608-
for (uint32_t i = 0; i < n_tokens; i++) {
617+
for (uint32_t i = 0; i < n_test; i++) {
618+
const auto idx = head_cur;
619+
609620
//const llama_pos pos = ubatch.pos[i];
610621
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
611622

@@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
615626
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
616627
// - mask SWA, using current max pos for that sequence in the cache
617628
// always insert in the cell with minimum pos
618-
bool can_use = cells.is_empty(head_cur + i);
629+
bool can_use = cells.is_empty(idx);
619630

620-
if (!can_use && cells.seq_count(head_cur + i) == 1) {
621-
const llama_pos pos_cell = cells.pos_get(head_cur + i);
631+
if (!can_use && cells.seq_count(idx) == 1) {
632+
const llama_pos pos_cell = cells.pos_get(idx);
622633

623634
// (disabled) causal mask
624635
// note: it's better to purge any "future" tokens beforehand
625-
//if (cells.seq_has(head_cur + i, seq_id)) {
636+
//if (cells.seq_has(idx, seq_id)) {
626637
// can_use = pos_cell >= pos;
627638
//}
628639

629640
if (!can_use) {
630-
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
641+
const llama_seq_id seq_id_cell = cells.seq_get(idx);
631642

632643
// SWA mask
633644
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
636647
}
637648
}
638649

639-
if (!can_use) {
640-
found = false;
641-
head_cur += i + 1;
642-
n_tested += i + 1;
650+
head_cur++;
651+
n_tested++;
652+
653+
if (can_use) {
654+
res.idxs[n_found] = idx;
655+
656+
n_found++;
657+
} else {
643658
break;
644659
}
645660
}
646661

647-
if (found) {
662+
if (n_found == n_tokens) {
648663
break;
649664
}
650665

666+
if (cont) {
667+
n_found = 0;
668+
}
669+
651670
if (n_tested >= cells.size()) {
652671
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
653672
return { };
654673
}
655674
}
656675

657-
slot_info res;
658-
659-
res.idxs.resize(n_tokens);
660-
for (uint32_t i = 0; i < n_tokens; ++i) {
661-
res.idxs[i] = head_cur + i;
676+
// we didn't find a suitable slot - return empty result
677+
if (n_found < n_tokens) {
678+
res.clear();
662679
}
663680

664681
return res;
@@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
15921609
ubatch.seq_id[i] = &dest_seq_id;
15931610
}
15941611

1595-
const auto sinfo = find_slot(ubatch);
1612+
const auto sinfo = find_slot(ubatch, true);
15961613
if (sinfo.empty()) {
15971614
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
15981615
return false;

src/llama-kv-cache-unified.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class llama_kv_cache_unified : public llama_memory_i {
4949
return idxs.empty();
5050
}
5151

52+
void clear() {
53+
idxs.clear();
54+
}
55+
5256
// TODO: implement
5357
//std::vector<idx_vec_t> seq_idxs;
5458
};
@@ -133,14 +137,10 @@ class llama_kv_cache_unified : public llama_memory_i {
133137

134138
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
135139

136-
// find a continuous slot of kv cells that can hold the ubatch
137-
// return the cell position where we can insert the ubatch
138-
// return -1 on failure to find a slot
139-
slot_info find_slot(const llama_ubatch & ubatch) const;
140-
141-
// find a set of kv cells that can hold the ubatch
142-
// TODO: implement
143-
//slot_info find_slot_ext(const llama_ubatch & ubatch) const;
140+
// find a slot of kv cells that can hold the ubatch
141+
// if cont == true, then the slot must be continuous
142+
// return empty slot_info on failure
143+
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
144144

145145
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
146146
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

0 commit comments

Comments (0)