Skip to content

Commit 39d0b1e

Browse files
committed
cont : kv-cells cp/set for non-cont slots
ggml-ci
1 parent f875d6c commit 39d0b1e

File tree

3 files changed

+64
-32
lines changed

3 files changed

+64
-32
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
388388

389389
struct state {
390390
uint32_t head_old; // old position of the head, before placing the ubatch
391-
uint32_t head_new; // new position of the head, after placing the ubatch
391+
392+
slot_info sinfo; // slot info for the ubatch
392393

393394
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
394395
};
@@ -409,21 +410,16 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
409410
// remember the position that we found
410411
res.push_back(sinfo_new);
411412

412-
// TODO: temporary
413-
if (supports_set_rows) {
414-
GGML_ASSERT(sinfo_new.is_cont());
415-
}
416-
417413
// store the old state of the cells in the recovery stack
418-
states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
414+
states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
419415

420416
// now emplace the ubatch
421417
apply_ubatch(sinfo_new, ubatch);
422418
}
423419

424420
// iterate backwards and restore the cells to their original state
425421
for (auto it = states.rbegin(); it != states.rend(); ++it) {
426-
cells.set(it->head_new, it->cells);
422+
cells.set(it->sinfo.idxs, it->cells);
427423
head = it->head_old;
428424
}
429425

src/llama-kv-cache-unified.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,6 @@ class llama_kv_cache_unified : public llama_memory_i {
4949
return idxs.empty();
5050
}
5151

52-
// TODO: tmp until kv cells support non-cont slots
53-
bool is_cont() const {
54-
bool res = true;
55-
56-
for (uint32_t i = 1; i < idxs.size(); ++i) {
57-
if (idxs[i] != idxs[i - 1] + 1) {
58-
res = false;
59-
break;
60-
}
61-
}
62-
63-
return res;
64-
}
65-
6652
// TODO: implement
6753
//std::vector<idx_vec_t> seq_idxs;
6854
};

src/llama-kv-cells.h

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,29 @@ class llama_kv_cells_unified {
105105
res.resize(n);
106106

107107
for (uint32_t j = 0; j < n; ++j) {
108-
res.pos[j] = pos[i + j];
109-
res.seq[j] = seq[i + j];
108+
const auto idx = i + j;
110109

111-
assert(shift[i + j] == 0);
110+
res.pos[j] = pos[idx];
111+
res.seq[j] = seq[idx];
112+
113+
assert(shift[idx] == 0);
114+
}
115+
116+
return res;
117+
}
118+
119+
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
120+
llama_kv_cells_unified res;
121+
122+
res.resize(idxs.size());
123+
124+
for (uint32_t j = 0; j < idxs.size(); ++j) {
125+
const auto idx = idxs[j];
126+
127+
res.pos[j] = pos[idx];
128+
res.seq[j] = seq[idx];
129+
130+
assert(shift[idx] == 0);
112131
}
113132

114133
return res;
@@ -119,26 +138,57 @@ class llama_kv_cells_unified {
119138
assert(i + other.pos.size() <= pos.size());
120139

121140
for (uint32_t j = 0; j < other.pos.size(); ++j) {
122-
if (pos[i + j] == -1 && other.pos[j] != -1) {
141+
const auto idx = i + j;
142+
143+
if (pos[idx] == -1 && other.pos[j] != -1) {
123144
used.insert(i + j);
124145
}
125146

126-
if (pos[i + j] != -1 && other.pos[j] == -1) {
147+
if (pos[idx] != -1 && other.pos[j] == -1) {
127148
used.erase(i + j);
128149
}
129150

130-
if (pos[i + j] != -1) {
151+
if (pos[idx] != -1) {
131152
seq_pos_rm(i + j);
132153
}
133154

134-
pos[i + j] = other.pos[j];
135-
seq[i + j] = other.seq[j];
155+
pos[idx] = other.pos[j];
156+
seq[idx] = other.seq[j];
136157

137-
if (pos[i + j] != -1) {
158+
if (pos[idx] != -1) {
138159
seq_pos_add(i + j);
139160
}
140161

141-
assert(shift[i + j] == 0);
162+
assert(shift[idx] == 0);
163+
}
164+
}
165+
166+
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
167+
assert(idxs.size() == other.pos.size());
168+
169+
for (uint32_t j = 0; j < other.pos.size(); ++j) {
170+
const auto idx = idxs[j];
171+
172+
if (pos[idx] == -1 && other.pos[j] != -1) {
173+
used.insert(idx);
174+
}
175+
176+
if (pos[idx] != -1 && other.pos[j] == -1) {
177+
used.erase(idx);
178+
}
179+
180+
if (pos[idx] != -1) {
181+
seq_pos_rm(idx);
182+
}
183+
184+
pos[idx] = other.pos[j];
185+
seq[idx] = other.seq[j];
186+
187+
if (pos[idx] != -1) {
188+
seq_pos_add(idx);
189+
}
190+
191+
assert(shift[idx] == 0);
142192
}
143193
}
144194

0 commit comments

Comments
 (0)