Skip to content

Commit 14554a8

Browse files
committed
cont : kv-cells cp/set for non-cont slots
ggml-ci
1 parent d4be34b commit 14554a8

File tree

3 files changed

+64
-32
lines changed

3 files changed

+64
-32
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
388388

389389
struct state {
390390
uint32_t head_old; // old position of the head, before placing the ubatch
391-
uint32_t head_new; // new position of the head, after placing the ubatch
391+
392+
slot_info sinfo; // slot info for the ubatch
392393

393394
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
394395
};
@@ -409,21 +410,16 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
409410
// remember the position that we found
410411
res.push_back(sinfo_new);
411412

412-
// TODO: temporary
413-
if (supports_set_rows) {
414-
GGML_ASSERT(sinfo_new.is_cont());
415-
}
416-
417413
// store the old state of the cells in the recovery stack
418-
states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
414+
states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
419415

420416
// now emplace the ubatch
421417
apply_ubatch(sinfo_new, ubatch);
422418
}
423419

424420
// iterate backwards and restore the cells to their original state
425421
for (auto it = states.rbegin(); it != states.rend(); ++it) {
426-
cells.set(it->head_new, it->cells);
422+
cells.set(it->sinfo.idxs, it->cells);
427423
head = it->head_old;
428424
}
429425

src/llama-kv-cache-unified.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,6 @@ class llama_kv_cache_unified : public llama_memory_i {
4949
return idxs.empty();
5050
}
5151

52-
// TODO: tmp until kv cells support non-cont slots
53-
bool is_cont() const {
54-
bool res = true;
55-
56-
for (uint32_t i = 1; i < idxs.size(); ++i) {
57-
if (idxs[i] != idxs[i - 1] + 1) {
58-
res = false;
59-
break;
60-
}
61-
}
62-
63-
return res;
64-
}
65-
6652
// TODO: implement
6753
//std::vector<idx_vec_t> seq_idxs;
6854
};

src/llama-kv-cells.h

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,29 @@ class llama_kv_cells_unified {
104104
res.resize(n);
105105

106106
for (uint32_t j = 0; j < n; ++j) {
107-
res.pos[j] = pos[i + j];
108-
res.seq[j] = seq[i + j];
107+
const auto idx = i + j;
109108

110-
assert(shift[i + j] == 0);
109+
res.pos[j] = pos[idx];
110+
res.seq[j] = seq[idx];
111+
112+
assert(shift[idx] == 0);
113+
}
114+
115+
return res;
116+
}
117+
118+
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
119+
llama_kv_cells_unified res;
120+
121+
res.resize(idxs.size());
122+
123+
for (uint32_t j = 0; j < idxs.size(); ++j) {
124+
const auto idx = idxs[j];
125+
126+
res.pos[j] = pos[idx];
127+
res.seq[j] = seq[idx];
128+
129+
assert(shift[idx] == 0);
111130
}
112131

113132
return res;
@@ -118,26 +137,57 @@ class llama_kv_cells_unified {
118137
assert(i + other.pos.size() <= pos.size());
119138

120139
for (uint32_t j = 0; j < other.pos.size(); ++j) {
121-
if (pos[i + j] == -1 && other.pos[j] != -1) {
140+
const auto idx = i + j;
141+
142+
if (pos[idx] == -1 && other.pos[j] != -1) {
122143
used.insert(i + j);
123144
}
124145

125-
if (pos[i + j] != -1 && other.pos[j] == -1) {
146+
if (pos[idx] != -1 && other.pos[j] == -1) {
126147
used.erase(i + j);
127148
}
128149

129-
if (pos[i + j] != -1) {
150+
if (pos[idx] != -1) {
130151
seq_pos_rm(i + j);
131152
}
132153

133-
pos[i + j] = other.pos[j];
134-
seq[i + j] = other.seq[j];
154+
pos[idx] = other.pos[j];
155+
seq[idx] = other.seq[j];
135156

136-
if (pos[i + j] != -1) {
157+
if (pos[idx] != -1) {
137158
seq_pos_add(i + j);
138159
}
139160

140-
assert(shift[i + j] == 0);
161+
assert(shift[idx] == 0);
162+
}
163+
}
164+
165+
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
166+
assert(idxs.size() == other.pos.size());
167+
168+
for (uint32_t j = 0; j < other.pos.size(); ++j) {
169+
const auto idx = idxs[j];
170+
171+
if (pos[idx] == -1 && other.pos[j] != -1) {
172+
used.insert(idx);
173+
}
174+
175+
if (pos[idx] != -1 && other.pos[j] == -1) {
176+
used.erase(idx);
177+
}
178+
179+
if (pos[idx] != -1) {
180+
seq_pos_rm(idx);
181+
}
182+
183+
pos[idx] = other.pos[j];
184+
seq[idx] = other.seq[j];
185+
186+
if (pos[idx] != -1) {
187+
seq_pos_add(idx);
188+
}
189+
190+
assert(shift[idx] == 0);
141191
}
142192
}
143193

0 commit comments

Comments
 (0)