@@ -246,26 +246,67 @@ struct llama_layer {
     struct ggml_tensor * w3;
 };

-struct llama_kv_cache {
+class llama_kv_cache {
+    // Hide ctx as it requires a custom deleter ggml_free.
+    std::shared_ptr<ggml_context> ctx;
+public:
+
     struct ggml_tensor * k = NULL;
     struct ggml_tensor * v = NULL;

-    struct ggml_context * ctx = NULL;
-
     llama_ctx_buffer buf;

     int n; // number of tokens currently in the cache

-    ~llama_kv_cache() {
-        if (ctx) {
-            ggml_free(ctx);
-        }
+    ggml_context * get_ctx() { return ctx.get(); }
+    ggml_context const * get_ctx() const { return ctx.get(); }
+    void set_ctx(ggml_context * ctx) {
+        this->ctx = std::shared_ptr<ggml_context>(ctx, ggml_free);
+    }

+    llama_kv_cache() = default;
+    ~llama_kv_cache() {
 #ifdef GGML_USE_CUBLAS
         ggml_cuda_free_data(k);
         ggml_cuda_free_data(v);
 #endif // GGML_USE_CUBLAS
     }
+    llama_kv_cache(llama_kv_cache const & rhs)
+        : ctx(rhs.ctx.get(), ggml_free)
+        , k(ggml_dup_tensor(rhs.ctx.get(), rhs.k))
+        , v(ggml_dup_tensor(rhs.ctx.get(), rhs.v))
+        , buf(rhs.buf)
+        , n(rhs.n)
+    { }
+    llama_kv_cache & operator=(llama_kv_cache const & rhs) {
+        this->~llama_kv_cache();
+        ctx = rhs.ctx;
+        k = rhs.k ? ggml_dup_tensor(rhs.ctx.get(), rhs.k) : NULL;
+        v = rhs.v ? ggml_dup_tensor(rhs.ctx.get(), rhs.v) : NULL;
+        buf = rhs.buf;
+        n = rhs.n;
+        return *this;
+    }
+    llama_kv_cache(llama_kv_cache && rhs)
+        : ctx(std::move(rhs.ctx))
+        , k(rhs.k)
+        , v(rhs.v)
+        , buf(std::move(rhs.buf))
+        , n(rhs.n)
+    {
+        rhs.k = NULL;
+        rhs.v = NULL;
+    }
+    llama_kv_cache & operator=(llama_kv_cache && rhs) {
+        this->~llama_kv_cache();
+        ctx = std::move(rhs.ctx);
+        std::swap(k, rhs.k);
+        std::swap(v, rhs.v);
+        buf = std::move(rhs.buf);
+        n = rhs.n;
+        return *this;
+    }
+
 };

 struct llama_vocab {
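
Aside (not part of the patch): the class above hands the ggml_context to a std::shared_ptr with ggml_free as a custom deleter, so the C-style context is released exactly once no matter how many copies of the cache share it. Below is a minimal, self-contained sketch of that pattern; c_ctx, c_ctx_init, and c_ctx_free are hypothetical stand-ins for ggml_context, ggml_init, and ggml_free.

    #include <cstdio>
    #include <memory>

    struct c_ctx { int id; };                                // stand-in for ggml_context

    c_ctx * c_ctx_init(int id) { return new c_ctx{id}; }     // stand-in for ggml_init

    void c_ctx_free(c_ctx * p) {                             // stand-in for ggml_free
        std::printf("freeing ctx %d\n", p->id);
        delete p;
    }

    class ctx_holder {
        std::shared_ptr<c_ctx> ctx;                          // deleter is stored with the pointer
    public:
        void set_ctx(c_ctx * raw) { ctx = std::shared_ptr<c_ctx>(raw, c_ctx_free); }
        c_ctx * get_ctx() const { return ctx.get(); }
    };

    int main() {
        ctx_holder a;
        a.set_ctx(c_ctx_init(1));
        ctx_holder b = a;                                    // copies share ownership of one context
        std::printf("ctx id seen through copy: %d\n", b.get_ctx()->id);
        return 0;
    }                                                        // c_ctx_free runs once, when the last holder dies

The point of the sketch is only the ownership rule: the deleter travels with the shared control block, so whichever owner is destroyed last calls the C-style free exactly once.
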
@@ -863,15 +904,15 @@ static bool kv_cache_init(
     params.mem_buffer = cache.buf.addr;
     params.no_alloc   = false;

-    cache.ctx = ggml_init(params);
+    cache.set_ctx(ggml_init(params));

-    if (!cache.ctx) {
+    if (!cache.get_ctx()) {
         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }

-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.k = ggml_new_tensor_1d(cache.get_ctx(), wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.get_ctx(), wtype, n_elements);
     ggml_set_name(cache.k, "cache_k");
     ggml_set_name(cache.v, "cache_v");

@@ -1410,7 +1451,7 @@ static struct ggml_cgraph * llama_build_graph(

     const auto & kv_self = lctx.kv_self;

-    LLAMA_ASSERT(!!kv_self.ctx);
+    LLAMA_ASSERT(!!kv_self.get_ctx());

     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_layer = hparams.n_layer;
@@ -2878,6 +2919,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 }

 struct beam {
+    llama_kv_cache kv_cache;
     std::vector<llama_token> tokens;
     float p; // Cumulative beam probability (renormalized with each token)
     // end-of-sentence
@@ -2948,13 +2990,15 @@ void fill_next_beams_by_top_probabilities(llama_context* ctx, std::vector<beam>&
             }
         } else if (next_beams.front().p < b.p) {
             std::pop_heap(next_beams.begin(), next_beams.end(), comp);
-            next_beams.back() = b;
+            next_beams.back() = std::move(b);
             std::push_heap(next_beams.begin(), next_beams.end(), comp);
         }
     } else {
         // b is not at end-of-sentence, so branch with next top_k tokens.
         if (!b.tokens.empty()) {
+            std::swap(ctx->kv_self, const_cast<beam&>(b).kv_cache);
             llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
+            std::swap(ctx->kv_self, const_cast<beam&>(b).kv_cache);
         }
         logit_info li(ctx);
         std::vector<llama_token_data> next_tokens = li.top_k(beam_width);
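
Aside (not part of the patch): the two std::swap calls above borrow the beam's own KV cache for the duration of llama_eval and then hand it back updated, so the context never has to copy a cache. A minimal sketch of that swap-in/swap-out pattern is below, using simplified hypothetical types (eval_ctx, kv_state, beam_like) rather than the real llama.cpp ones.

    #include <cstdio>
    #include <utility>
    #include <vector>

    struct kv_state { std::vector<int> cells; };   // simplified stand-in for llama_kv_cache

    struct eval_ctx { kv_state kv_self; };         // simplified stand-in for llama_context

    // Pretend evaluation: appends tokens to whatever cache the context currently holds.
    void eval(eval_ctx & ctx, const std::vector<int> & tokens) {
        ctx.kv_self.cells.insert(ctx.kv_self.cells.end(), tokens.begin(), tokens.end());
    }

    struct beam_like { kv_state kv_cache; std::vector<int> tokens; };

    void eval_beam(eval_ctx & ctx, beam_like & b) {
        std::swap(ctx.kv_self, b.kv_cache);        // borrow the beam's cache
        eval(ctx, b.tokens);
        std::swap(ctx.kv_self, b.kv_cache);        // hand it back, now updated
    }

    int main() {
        eval_ctx  ctx;
        beam_like b{ {}, {1, 2, 3} };
        eval_beam(ctx, b);
        std::printf("beam cache size: %zu, ctx cache size: %zu\n",
                    b.kv_cache.cells.size(), ctx.kv_self.cells.size());   // prints 3, 0
        return 0;
    }
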
@@ -3006,11 +3050,11 @@ const char* llama_beam_search(llama_context * ctx, int const beam_width,

     std::vector<beam> beams;
     beams.reserve(beam_width);
-    beams.push_back({{}, 1.0});
+    beams.push_back({ctx->kv_self, {}, 1.0});
     std::vector<beam> next_beams;
     next_beams.reserve(beam_width);
     // Loop while there are any beams that have not yet reached end-of-sentence.
-    // If the top beam is at end-of-sentence, then finish since all other
+    // If the highest probability beam is at end-of-sentence, then finish since all other
     // beam probabilities can only decrease.
     auto const eos = [](beam const & b) { return b.eos(); };
     for (int i=0; i<n_predict && !eos(top_beam(beams)) &&