Commit 4b20567

Make llama_kv_cache copyable+moveable.
1 parent 6a2ac4f commit 4b20567

File tree

1 file changed: +59 -15 lines changed

llama.cpp

Lines changed: 59 additions & 15 deletions
@@ -246,26 +246,67 @@ struct llama_layer {
     struct ggml_tensor * w3;
 };
 
-struct llama_kv_cache {
+class llama_kv_cache {
+    // Hide ctx as it requires a custom deleter ggml_free.
+    std::shared_ptr<ggml_context> ctx;
+public:
+
     struct ggml_tensor * k = NULL;
     struct ggml_tensor * v = NULL;
 
-    struct ggml_context * ctx = NULL;
-
     llama_ctx_buffer buf;
 
     int n; // number of tokens currently in the cache
 
-    ~llama_kv_cache() {
-        if (ctx) {
-            ggml_free(ctx);
-        }
+    ggml_context* get_ctx() { return ctx.get(); }
+    ggml_context const* get_ctx() const { return ctx.get(); }
+    void set_ctx(ggml_context* ctx) {
+        this->ctx = std::shared_ptr<ggml_context>(ctx, ggml_free);
+    }
 
+    llama_kv_cache() = default;
+    ~llama_kv_cache() {
 #ifdef GGML_USE_CUBLAS
         ggml_cuda_free_data(k);
         ggml_cuda_free_data(v);
 #endif // GGML_USE_CUBLAS
     }
+    llama_kv_cache(llama_kv_cache const& rhs)
+        : ctx(rhs.ctx.get(), ggml_free)
+        , k(ggml_dup_tensor(rhs.ctx.get(), rhs.k))
+        , v(ggml_dup_tensor(rhs.ctx.get(), rhs.v))
+        , buf(rhs.buf)
+        , n(rhs.n)
+    { }
+    llama_kv_cache& operator=(llama_kv_cache const& rhs) {
+        this->~llama_kv_cache();
+        ctx = rhs.ctx;
+        k = rhs.k ? ggml_dup_tensor(rhs.ctx.get(), rhs.k) : NULL;
+        v = rhs.v ? ggml_dup_tensor(rhs.ctx.get(), rhs.v) : NULL;
+        buf = rhs.buf;
+        n = rhs.n;
+        return *this;
+    }
+    llama_kv_cache(llama_kv_cache&& rhs)
+        : ctx(std::move(rhs.ctx))
+        , k(rhs.k)
+        , v(rhs.v)
+        , buf(std::move(rhs.buf))
+        , n(rhs.n)
+    {
+        rhs.k = NULL;
+        rhs.v = NULL;
+    }
+    llama_kv_cache& operator=(llama_kv_cache&& rhs) {
+        this->~llama_kv_cache();
+        ctx = std::move(rhs.ctx);
+        std::swap(k, rhs.k);
+        std::swap(v, rhs.v);
+        buf = std::move(rhs.buf);
+        n = rhs.n;
+        return *this;
+    }
+
 };
 
 struct llama_vocab {
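
Aside (not part of the commit): the central change in this hunk is that the ggml_context is now held by a std::shared_ptr whose custom deleter is ggml_free, so the cache no longer frees the context by hand in its destructor, and copy/move of the cache is spelled out by the special members above. Below is a minimal, self-contained sketch of how a shared_ptr with a custom deleter behaves under copy and move. The names context, release_context, and cache are invented stand-ins for illustration only, not ggml or llama.cpp symbols, and the sketch does not reproduce the class above member for member.

#include <cstdio>
#include <memory>
#include <utility>

struct context { int id; };                 // stand-in for ggml_context
static void release_context(context * c) {  // stand-in for ggml_free
    std::printf("releasing context %d\n", c->id);
    delete c;
}

class cache {
    std::shared_ptr<context> ctx;           // shared ownership with a custom deleter
public:
    void set_ctx(context * c) { ctx = std::shared_ptr<context>(c, release_context); }
    context * get_ctx() const { return ctx.get(); }
    long owners() const { return ctx.use_count(); }
};

int main() {
    cache a;
    a.set_ctx(new context{42});
    cache b = a;                            // copy: both objects share context 42
    std::printf("owners = %ld\n", b.owners());   // prints 2
    cache c = std::move(a);                 // move: ownership transfers, nothing is freed yet
    std::printf("owners = %ld\n", c.owners());   // still 2 (b and c)
    return 0;
}   // "releasing context 42" is printed exactly once, when the last owner is destroyed
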
@@ -863,15 +904,15 @@ static bool kv_cache_init(
     params.mem_buffer = cache.buf.addr;
     params.no_alloc = false;
 
-    cache.ctx = ggml_init(params);
+    cache.set_ctx(ggml_init(params));
 
-    if (!cache.ctx) {
+    if (!cache.get_ctx()) {
         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.k = ggml_new_tensor_1d(cache.get_ctx(), wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.get_ctx(), wtype, n_elements);
     ggml_set_name(cache.k, "cache_k");
     ggml_set_name(cache.v, "cache_v");
 
@@ -1410,7 +1451,7 @@ static struct ggml_cgraph * llama_build_graph(
 
     const auto & kv_self = lctx.kv_self;
 
-    LLAMA_ASSERT(!!kv_self.ctx);
+    LLAMA_ASSERT(!!kv_self.get_ctx());
 
     const int64_t n_embd = hparams.n_embd;
     const int64_t n_layer = hparams.n_layer;
@@ -2878,6 +2919,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 }
 
 struct beam {
+    llama_kv_cache kv_cache;
     std::vector<llama_token> tokens;
     float p; // Cumulative beam probability (renormalized with each token)
     // end-of-sentence
@@ -2948,13 +2990,15 @@ void fill_next_beams_by_top_probabilities(llama_context* ctx, std::vector<beam>&
             }
         } else if (next_beams.front().p < b.p) {
             std::pop_heap(next_beams.begin(), next_beams.end(), comp);
-            next_beams.back() = b;
+            next_beams.back() = std::move(b);
             std::push_heap(next_beams.begin(), next_beams.end(), comp);
         }
     } else {
         // b is not at end-of-sentence, so branch with next top_k tokens.
         if (!b.tokens.empty()) {
+            std::swap(ctx->kv_self, const_cast<beam&>(b).kv_cache);
             llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
+            std::swap(ctx->kv_self, const_cast<beam&>(b).kv_cache);
         }
         logit_info li(ctx);
         std::vector<llama_token_data> next_tokens = li.top_k(beam_width);
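
Aside (not part of the commit): the pair of std::swap calls around llama_eval above is what gives each beam its own decoding state. The beam's private kv_cache is swapped into the context while that beam's tokens are evaluated, then swapped back out so the context's own cache is left as it was. Here is a minimal, self-contained sketch of that borrow-and-restore pattern; kv_state, eval_ctx, and eval_one are invented stand-ins, not the llama.cpp API.

#include <cassert>
#include <string>
#include <utility>

struct kv_state { std::string history; };   // stand-in for llama_kv_cache
struct eval_ctx { kv_state kv_self; };      // stand-in for llama_context

// Pretend evaluation: extends whatever cache is currently active in the context.
static void eval_one(eval_ctx & ctx) { ctx.kv_self.history += "+tok"; }

int main() {
    eval_ctx ctx{ {"shared"} };
    kv_state beam_kv{ "beam" };

    std::swap(ctx.kv_self, beam_kv);        // lend the beam's cache to the context
    eval_one(ctx);                          // evaluation extends the beam's cache
    std::swap(ctx.kv_self, beam_kv);        // restore; the context's cache is unchanged

    assert(ctx.kv_self.history == "shared");
    assert(beam_kv.history == "beam+tok");
    return 0;
}
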
@@ -3006,11 +3050,11 @@ const char* llama_beam_search(llama_context * ctx, int const beam_width,
 
     std::vector<beam> beams;
     beams.reserve(beam_width);
-    beams.push_back({{}, 1.0});
+    beams.push_back({ctx->kv_self, {}, 1.0});
     std::vector<beam> next_beams;
     next_beams.reserve(beam_width);
     // Loop while there are any beams that have not yet reached end-of-sentence.
-    // If the top beam is at end-of-sentence, then finish since all other
+    // If the highest probability beam is at end-of-sentence, then finish since all other
     // beam probabilities can only decrease.
     auto const eos = [](beam const& b) { return b.eos(); };
     for (int i=0 ; i<n_predict && !eos(top_beam(beams)) &&
