@@ -86,7 +86,7 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
@@ -209,8 +209,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
     ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
 
-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
-    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
+    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
+    ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
 
     ctx_size += (5 + 10*n_layer)*256; // object overhead
@@ -296,8 +296,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     const int n_mem      = n_layer*n_ctx;
     const int n_elements = n_embd*n_mem;
 
-    model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-    model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+    model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
+    model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);
 
     const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
@@ -819,8 +819,9 @@ int main(int argc, char ** argv) {
 
     // load the model
     {
+        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
 
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }