
Commit d4be34b

cont : migrate to using set of indices instead of slot head
ggml-ci
1 parent 113c762 commit d4be34b

6 files changed: +143 -85 lines changed
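
Note: the hunks below only touch call sites; the `slot_info` type itself is declared in src/llama-kv-cache-unified.h, which is not part of this excerpt. A minimal sketch of its shape, inferred purely from how the diff uses it (the member and method names come from the code below, everything else is an assumption):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: inferred from usage in this diff, not the actual declaration.
// find_slot() returns one slot_info per ubatch; prepare() collects them into a
// slot_info_vec_t that the memory contexts later replay via apply_ubatch().
struct slot_info {
    std::vector<uint32_t> idxs; // target cell index for each token of the ubatch

    bool empty() const { return idxs.empty(); }    // replaces the old `head < 0` sentinel
    uint32_t head() const { return idxs.front(); } // first cell, where a single offset is still needed

    bool is_cont() const { // true while the indices form one contiguous run
        for (std::size_t i = 1; i < idxs.size(); ++i) {
            if (idxs[i] != idxs[i - 1] + 1) {
                return false;
            }
        }
        return true;
    }
};

using slot_info_vec_t = std::vector<slot_info>; // one entry per ubatch, as returned by prepare()
```

The key change is that an empty `idxs` vector replaces the old `-1` sentinel, while `head()` keeps the existing contiguous-offset code paths working until they are migrated.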

src/llama-kv-cache-unified-iswa.cpp (16 additions & 16 deletions)
@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
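
Note: the `assert(sinfos_base.size() == sinfos_swa.size())` above holds because `prepare()` returns exactly one `slot_info` per ubatch for each cache. A hypothetical helper spelling out that invariant (not code from the tree; it assumes the `slot_info` sketch above):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical check: both caches place the same ubatches, so their slot_info
// vectors are indexed by ubatch position and each entry holds one cell index
// per token of that ubatch.
static void check_iswa_placements(
        const slot_info_vec_t &          sinfos_base,
        const slot_info_vec_t &          sinfos_swa,
        const std::vector<std::size_t> & n_tokens_per_ubatch) {
    assert(sinfos_base.size() == n_tokens_per_ubatch.size());
    assert(sinfos_swa .size() == n_tokens_per_ubatch.size());

    for (std::size_t i = 0; i < n_tokens_per_ubatch.size(); ++i) {
        assert(sinfos_base[i].idxs.size() == n_tokens_per_ubatch[i]);
        assert(sinfos_swa [i].idxs.size() == n_tokens_per_ubatch[i]);
    }
}
```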

src/llama-kv-cache-unified-iswa.h (4 additions & 2 deletions)
@@ -74,6 +74,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
@@ -90,8 +92,8 @@ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();

src/llama-kv-cache-unified.cpp (64 additions & 44 deletions)
@@ -334,13 +334,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -383,8 +383,8 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
@@ -400,20 +400,25 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
 
     for (const auto & ubatch : ubatches) {
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remeber the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
+
+        // TODO: temporary
+        if (supports_set_rows) {
+            GGML_ASSERT(sinfo_new.is_cont());
+        }
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
@@ -520,7 +525,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -533,7 +538,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {
@@ -649,37 +654,48 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        res.idxs[i] = head_cur + i;
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
     for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
+    assert(ubatch.n_tokens == sinfo.idxs.size());
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+        const auto idx = sinfo.idxs[i];
+
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos pos = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
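
Note: in this commit `find_slot()` still produces a single contiguous run starting at `head_cur`, which is why the `GGML_ASSERT(sinfo_new.is_cont())` in `prepare()` and the `sinfo.head()` shortcuts elsewhere remain valid. A small standalone illustration of the equivalences the surrounding hunks rely on (uses only the standard library; the names mirror the sketch above):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standalone illustration: a contiguous slot starting at head_cur is just the
// special case of the index set {head_cur, head_cur + 1, ..., head_cur + n - 1}.
int main() {
    const uint32_t head_cur = 128;
    const uint32_t n_tokens = 4;

    std::vector<uint32_t> idxs(n_tokens);
    for (uint32_t i = 0; i < n_tokens; ++i) {
        idxs[i] = head_cur + i; // same loop as the new find_slot() return path
    }

    // equivalences that hold while the set is contiguous:
    assert(idxs.front() == head_cur);               // sinfo.head()      == old head_cur
    assert(idxs.back() + 1 == head_cur + n_tokens); // sinfo.idxs.back() + 1 == old head update

    return 0;
}
```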

@@ -700,7 +716,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -753,7 +769,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -772,12 +788,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -814,19 +830,19 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
         v_cur = ggml_transpose(ctx, v_cur);
 
         v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -837,7 +853,7 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = head_cur + i;
+        data[i] = sinfo.idxs[i];
     }
 }
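
Note: writing `sinfo.idxs[i]` instead of `head_cur + i` is what eventually allows non-contiguous placements once the `supports_set_rows` path consumes these per-token indices. A CPU-side analogue of the scatter that the index set describes (hypothetical sketch, not the ggml graph path):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical sketch: scatter the per-token rows into the cache using an
// arbitrary index set instead of a contiguous block starting at head_cur.
// Caller guarantees cache holds enough cells and cur holds idxs.size() rows.
static void scatter_rows(
        std::vector<float> &          cache,    // n_cells  * row_size floats
        const std::vector<float> &    cur,      // n_tokens * row_size floats
        const std::vector<uint32_t> & idxs,     // destination cell per token
        std::size_t                   row_size) {
    for (std::size_t i = 0; i < idxs.size(); ++i) {
        std::memcpy(&cache[idxs[i] * row_size], &cur[i * row_size], row_size * sizeof(float));
    }
}

int main() {
    const std::size_t row_size = 2;
    std::vector<float> cache(8 * row_size, 0.0f);       // 8 cells
    const std::vector<float> cur = {1, 1, 2, 2, 3, 3};  // 3 token rows
    const std::vector<uint32_t> idxs = {5, 6, 7};       // contiguous today, arbitrary later
    scatter_rows(cache, cur, idxs, row_size);
    return 0;
}
```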

@@ -1580,13 +1596,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             ubatch.seq_id[i] = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(ubatch);
-        if (head_cur < 0) {
+        const auto sinfo = find_slot(ubatch);
+        if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, ubatch);
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1772,7 +1790,10 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1787,16 +1808,16 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 
@@ -1813,10 +1834,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }
@@ -1828,7 +1848,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1844,19 +1864,19 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
 void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, head);
+    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
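
Note: the context API keeps the same iteration pattern as before; only the internal bookkeeping moved from a vector of heads to a vector of `slot_info`. A simplified sketch of how a caller walks the prepared ubatches (the real loop lives in llama_context and also builds and evaluates the graph; error handling omitted):

```cpp
// Simplified sketch only, using the methods shown in this diff.
static void process_prepared_batch(llama_kv_cache_unified_context * mctx) {
    do {
        const llama_ubatch & ubatch = mctx->get_ubatch(); // ubatches[i_cur]

        mctx->apply(); // writes the ubatch into the cells described by sinfos[i_cur]

        (void) ubatch; // ... build the graph for `ubatch`; its inputs are filled via
                       //     set_input_kv_idxs(), set_input_kq_mask(), etc. ...
    } while (mctx->next()); // advance i_cur to the next ubatch
}
```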
