
Commit c9c0b75

llama : simplify gguf_file_saver
1 parent 66ce19a

2 files changed: +31 -65 lines changed

examples/gguf/gguf-llama-simple.cpp

Lines changed: 5 additions & 2 deletions
@@ -74,7 +74,9 @@ int main(int argc, char ** argv) {
     // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
     // example, we will just stop the loop once this cache is full or once an end of stream is detected.

-    while (llama_get_kv_cache_token_count(ctx) < max_context_size) {
+    const int n_gen = std::min(32, max_context_size);
+
+    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
         // evaluate the transformer

         if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
@@ -114,13 +116,14 @@ int main(int argc, char ** argv) {

         // push this new token for next evaluation
         tokens_list.push_back(new_token_id);
-
     }

     llama_free(ctx);
     llama_free_model(model);

     llama_backend_free();

+    fprintf(stderr, "\n\n");
+
     return 0;
 }
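For readability, a minimal sketch of how the example's generation loop reads after this change (identifiers are taken from the diff above; the loop body is elided):

    // generate at most 32 tokens so the minimal example always terminates quickly,
    // even when the model's context window is much larger
    const int n_gen = std::min(32, max_context_size);

    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
        // evaluate the model, sample the next token, print it, and append it to
        // tokens_list, exactly as in the rest of the example
    }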

gguf-llama.cpp

Lines changed: 26 additions & 63 deletions
@@ -701,11 +701,11 @@ struct gguf_file_saver {
     size_t info_offset;
     size_t tensor_offset = 0;

-    gguf_file_saver(const char * fname, gguf_file_loader * fl, enum llama_ftype new_ftype)
+    gguf_file_saver(const char * fname, gguf_file_loader * fl)
         : file(fname, "wb"), fl(fl) {
         fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
         write_header();
-        write_hparams(new_ftype);
+        write_kv();
     }

     void write_header() {
@@ -744,75 +744,38 @@ struct gguf_file_saver {
         file.write_arr<float>(key, type, data);
     }

-    void write_hparams(enum llama_ftype new_ftype) {
+    // re-write the key-value section from the loaded file
+    void write_kv() {
         const int32_t n_kv = gguf_get_n_kv(fl->gguf_ctx);
         for (int i = 0; i < n_kv; ++i) {
             const char * key = gguf_get_key(fl->gguf_ctx, i);
             if (strcmp(key, "general.quantization_version") == 0) {
-                file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, new_ftype);
+                file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION);
             } else {
                 const gguf_type vtype = gguf_get_kv_type(fl->gguf_ctx, i);

-                bool bool_val;
-                float f32_val;
-                int16_t i16_val;
-                int32_t i32_val;
-                int8_t i8_val;
-                std::string str_val;
-                uint16_t u16_val;
-                uint32_t u32_val;
-                uint8_t u8_val;
-                gguf_type arr_type;
-                int n_arr;
-
                 switch (vtype) {
-                    case GGUF_TYPE_BOOL:
-                        bool_val = gguf_get_val_bool(fl->gguf_ctx, i);
-                        file.write_val<bool>(key, GGUF_TYPE_BOOL, bool_val);
-                        break;
-                    case GGUF_TYPE_FLOAT32:
-                        f32_val = gguf_get_val_f32(fl->gguf_ctx, i);
-                        file.write_val<float>(key, GGUF_TYPE_FLOAT32, f32_val);
-                        break;
-                    case GGUF_TYPE_INT16:
-                        i16_val = gguf_get_val_i16(fl->gguf_ctx, i);
-                        file.write_val<int16_t>(key, GGUF_TYPE_INT16, i16_val);
-                        break;
-                    case GGUF_TYPE_INT32:
-                        i32_val = gguf_get_val_i32(fl->gguf_ctx, i);
-                        file.write_val<int32_t>(key, GGUF_TYPE_INT32, i32_val);
-                        break;
-                    case GGUF_TYPE_INT8:
-                        i8_val = gguf_get_val_i8(fl->gguf_ctx, i);
-                        file.write_val<int8_t>(key, GGUF_TYPE_INT8, i8_val);
-                        break;
-                    case GGUF_TYPE_STRING:
-                        str_val = gguf_get_val_str(fl->gguf_ctx, i);
-                        file.write_str(key, GGUF_TYPE_STRING, str_val);
-                        break;
-                    case GGUF_TYPE_UINT16:
-                        u16_val = gguf_get_val_u16(fl->gguf_ctx, i);
-                        file.write_val<uint16_t>(key, GGUF_TYPE_UINT16, u16_val);
-                        break;
-                    case GGUF_TYPE_UINT32:
-                        u32_val = gguf_get_val_u32(fl->gguf_ctx, i);
-                        file.write_val<uint32_t>(key, GGUF_TYPE_UINT32, u32_val);
-                        break;
-                    case GGUF_TYPE_UINT8:
-                        u8_val = gguf_get_val_u8(fl->gguf_ctx, i);
-                        file.write_val<uint8_t>(key, GGUF_TYPE_UINT8, u8_val);
-                        break;
+                    case GGUF_TYPE_BOOL:    file.write_val<bool>    (key, GGUF_TYPE_BOOL,    gguf_get_val_bool(fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_FLOAT32: file.write_val<float>   (key, GGUF_TYPE_FLOAT32, gguf_get_val_f32 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_INT16:   file.write_val<int16_t> (key, GGUF_TYPE_INT16,   gguf_get_val_i16 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_INT32:   file.write_val<int32_t> (key, GGUF_TYPE_INT32,   gguf_get_val_i32 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_INT8:    file.write_val<int8_t>  (key, GGUF_TYPE_INT8,    gguf_get_val_i8  (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_STRING:  file.write_str          (key, GGUF_TYPE_STRING,  gguf_get_val_str (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_UINT16:  file.write_val<uint16_t>(key, GGUF_TYPE_UINT16,  gguf_get_val_u16 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_UINT32:  file.write_val<uint32_t>(key, GGUF_TYPE_UINT32,  gguf_get_val_u32 (fl->gguf_ctx, i)); break;
+                    case GGUF_TYPE_UINT8:   file.write_val<uint8_t> (key, GGUF_TYPE_UINT8,   gguf_get_val_u8  (fl->gguf_ctx, i)); break;
                     case GGUF_TYPE_ARRAY:
-                        arr_type = gguf_get_arr_type(fl->gguf_ctx, i);
-                        n_arr = gguf_get_arr_n (fl->gguf_ctx, i);
-                        if (arr_type == GGUF_TYPE_FLOAT32) {
-                            write_hparam_arr_f32(key, arr_type, i, n_arr);
-                        } else if (arr_type == GGUF_TYPE_STRING) {
-                            write_hparam_arr_str(key, GGUF_TYPE_STRING, i, n_arr);
-                        } else {
-                            throw std::runtime_error("not implemented");
-                        }
-                        break;
+                        {
+                            const gguf_type arr_type = gguf_get_arr_type(fl->gguf_ctx, i);
+                            const int       n_arr    = gguf_get_arr_n   (fl->gguf_ctx, i);
+                            if (arr_type == GGUF_TYPE_FLOAT32) {
+                                write_hparam_arr_f32(key, arr_type, i, n_arr);
+                            } else if (arr_type == GGUF_TYPE_STRING) {
+                                write_hparam_arr_str(key, arr_type, i, n_arr);
+                            } else {
+                                throw std::runtime_error("not implemented");
+                            }
+                        } break;
                     default:
                         throw std::runtime_error(format("cannot recognize value type for key %s\n", key));
                 }
@@ -3264,7 +3227,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }

     std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false));
-    gguf_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get(), params->ftype);
+    gguf_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get());

 #ifdef GGML_USE_K_QUANTS
     int n_attention_wv = 0;