Skip to content

Commit a728a0d

Browse files
committed
llama: make MEM_REQ_EVAL depend on n_ctx
1 parent 5c6eed3 commit a728a0d

File tree

1 file changed: +8 additions, -8 deletions

llama.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -119,14 +119,14 @@ static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
119119

120120
// this is mostly needed for temporary mul_mat buffers to dequantize the data
121121
// not actually needed if BLAS is disabled
122-
static const std::map<e_model, size_t> & MEM_REQ_EVAL()
122+
static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
123123
{
124124
static std::map<e_model, size_t> k_sizes = {
125-
{ MODEL_3B, 640ull * MB },
126-
{ MODEL_7B, 768ull * MB },
127-
{ MODEL_13B, 1024ull * MB },
128-
{ MODEL_30B, 1280ull * MB },
129-
{ MODEL_65B, 1536ull * MB },
125+
{ MODEL_3B, ((size_t) n_ctx / 256ull + 512ull) * MB },
126+
{ MODEL_7B, ((size_t) n_ctx / 256ull + 768ull) * MB },
127+
{ MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
128+
{ MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
129+
{ MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
130130
};
131131
return k_sizes;
132132
}
@@ -1140,7 +1140,7 @@ static void llama_model_load_internal(
11401140
mmapped_size - vram_weights + // weights in VRAM not in memory
11411141
MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
11421142
MEM_REQ_SCRATCH1().at(model.type) +
1143-
MEM_REQ_EVAL().at (model.type);
1143+
MEM_REQ_EVAL(hparams.n_ctx).at(model.type);
11441144

11451145
// this is the memory required by one llama_state
11461146
const size_t mem_required_state =
@@ -2652,7 +2652,7 @@ struct llama_context * llama_new_context_with_model(
26522652
ctx->embedding.resize(hparams.n_embd);
26532653
}
26542654

2655-
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
2655+
ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));
26562656

26572657
ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
26582658
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));

Comments (0)