@@ -102,9 +102,6 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
-
-    // work buffer for transformer evaluation
-    std::vector<uint8_t> buf_eval;
 };
 
 struct llama_context_params llama_context_default_params() {
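
This hunk drops the per-context work buffer in favor of the function-local static restored in the next hunk. A minimal sketch of the two ownership models, with hypothetical names (`ctx_with_member_buffer`, `shared_eval_buffer`) that are not part of llama.cpp:

```cpp
#include <cstdint>
#include <cstdlib>
#include <vector>

// Model removed by this hunk: each context owns its buffer,
// sized at init time and freed automatically with the context.
struct ctx_with_member_buffer {
    std::vector<uint8_t> buf_eval;
};

// Model restored by this hunk: one process-wide buffer shared by all
// contexts, allocated on first use and never freed.
static void * shared_eval_buffer(size_t size) {
    static void * buf = malloc(size);
    return buf;
}
```

The shared static is simpler, but it is no longer independent per context, so two contexts evaluating concurrently would contend for the same buffer.
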
@@ -630,19 +627,27 @@ static bool llama_eval_internal(
     const int n_rot   = hparams.n_embd/hparams.n_head;
 
     auto & mem_per_token = lctx.mem_per_token;
-    auto & buf_eval = lctx.buf_eval;
 
-    if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
-        const size_t buf_size_new = 1.618*buf_eval.size();
+    // TODO: fix this hardcoded size
+    static size_t buf_size = 512u*1024*1024;
+    static void * buf = malloc(buf_size);
 
-        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new);
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
+        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
 
-        buf_eval.resize(buf_size_new);
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
     }
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_eval.size(),
-        /*.mem_buffer =*/ buf_eval.data(),
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
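
The grow-on-demand policy restored above can be read in isolation. Below is a self-contained sketch under the same assumptions (512 MB initial size, 30% headroom); the helper name `ensure_scratch` is illustrative, not from llama.cpp:

```cpp
#include <cstdio>
#include <cstdlib>

// Sketch of the scratch-buffer policy in this hunk: start from a fixed
// 512 MB allocation and grow it with realloc() whenever the caller's
// estimate of the required size exceeds the current capacity.
static bool ensure_scratch(size_t needed, void ** out_buf, size_t * out_size) {
    static size_t buf_size = 512u*1024*1024;
    static void * buf      = malloc(buf_size);

    if (needed > buf_size) {
        const size_t buf_size_new = 1.3*needed; // +30% for ggml object overhead
        void * buf_new = realloc(buf, buf_size_new);
        if (buf_new == nullptr) {
            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size_new);
            return false;
        }
        buf      = buf_new;
        buf_size = buf_size_new;
    }

    *out_buf  = buf;
    *out_size = buf_size;
    return true;
}
```

One difference from the hunk is deliberate: assigning `realloc`'s result to a temporary avoids losing the old block when the call fails, whereas `buf = realloc(buf, buf_size)` overwrites the only pointer to it. In the diff this is harmless, since evaluation aborts on failure anyway.
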
@@ -827,11 +832,10 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
     }
 
-    if (N == 1) {
-        mem_per_token = ggml_used_mem(ctx0)/(n_past + N);
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
     }
-
-    //fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024);
+    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
 
     ggml_free(ctx0);
 
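
The check above replaces the reverted `N == 1` heuristic: instead of averaging over `n_past + N` on single-token calls, the restored code latches `ggml_used_mem(ctx0)/N` from the first evaluation and reuses it for all later buffer sizing. A rough standalone model of that two-phase behavior, with an illustrative `eval_mem_model` type that does not exist in llama.cpp:

```cpp
#include <cstddef>

// Two-phase estimate: the first evaluation runs inside the default 512 MB
// buffer and records bytes used per token; later evaluations use that
// figure (times N, plus headroom) to grow the buffer up front.
struct eval_mem_model {
    size_t mem_per_token = 0; // 0 until the first evaluation is measured

    // Expected ggml memory for an N-token batch; 0 means "unknown yet".
    size_t expected_bytes(int N) const {
        return mem_per_token*N;
    }

    // Mirrors `if (mem_per_token == 0)` after the graph is computed.
    void record(size_t used_mem, int N) {
        if (mem_per_token == 0) {
            mem_per_token = used_mem/N;
        }
    }
};
```

Usage would mirror `llama_eval_internal`: call `expected_bytes(N)` before `ggml_init` to decide whether to grow the buffer, then `record(ggml_used_mem(ctx0), N)` after the graph runs.
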
@@ -1412,8 +1416,6 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }
 
-    ctx->buf_eval.resize(512u*1024u*1024u);
-
     return ctx;
 }
 