@@ -119,14 +119,14 @@ static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()

 // this is mostly needed for temporary mul_mat buffers to dequantize the data
 // not actually needed if BLAS is disabled
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
+static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   640ull * MB },
-        { MODEL_7B,   768ull * MB },
-        { MODEL_13B, 1024ull * MB },
-        { MODEL_30B, 1280ull * MB },
-        { MODEL_65B, 1536ull * MB },
+        { MODEL_3B,  ((size_t) n_ctx / 256ull +  512ull) * MB },
+        { MODEL_7B,  ((size_t) n_ctx / 256ull +  768ull) * MB },
+        { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
+        { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
+        { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
     };
     return k_sizes;
 }
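
With this hunk the eval buffer is no longer a fixed per-model constant: integer division by 256 adds 1 MB on top of the old base size for every 256 tokens of context. A minimal standalone sketch of that sizing formula, assuming MB = 1024*1024 as elsewhere in the file; eval_buf_size and the checked values are illustrative, not part of the patch:

#include <cassert>
#include <cstddef>
#include <cstdio>

static const size_t MB = 1024*1024;

// Same arithmetic as the patched MEM_REQ_EVAL(): base MB plus 1 MB per 256 ctx tokens.
static size_t eval_buf_size(size_t base_mb, int n_ctx) {
    return ((size_t) n_ctx / 256ull + base_mb) * MB;
}

int main() {
    // With the 7B base of 768 MB: n_ctx = 512 -> 770 MB, n_ctx = 2048 -> 776 MB.
    assert(eval_buf_size(768,  512) == 770ull * MB);
    assert(eval_buf_size(768, 2048) == 776ull * MB);
    printf("7B eval buffer @ 2048 ctx: %zu MB\n", eval_buf_size(768, 2048) / MB);
    return 0;
}
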
@@ -1140,7 +1140,7 @@ static void llama_model_load_internal(
             mmapped_size - vram_weights + // weights in VRAM not in memory
             MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
             MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
+            MEM_REQ_EVAL(hparams.n_ctx).at(model.type);

         // this is the memory required by one llama_state
         const size_t mem_required_state =
@@ -2652,7 +2652,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }

-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
+        ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));

         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));