@@ -181,14 +181,13 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
-    uint32_t n_vocab_sp = 0;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx   = 512;   // this is provided as user input?
     uint32_t n_embd  = 4096;
     uint32_t n_mult  = 256;
     uint32_t n_head  = 32;
     uint32_t n_head_kv = 32;
     uint32_t n_layer = 32;
-    uint32_t n_rot   = 64;
 
     // LLaMAv2
     // TODO: load from model data hparams
@@ -499,7 +498,6 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
-    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };
 
 struct llama_file_loader {
@@ -515,6 +513,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
@@ -537,7 +536,6 @@ struct llama_file_loader {
                 case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
                 case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
                 case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
-                case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
             }
         }
 
@@ -546,18 +544,18 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
-        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd = file.read_u32();
         hparams.n_mult = file.read_u32();
         hparams.n_head = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype = (enum llama_ftype) file.read_u32();
 
         // LLaMAv2
         // TODO: read from header
         hparams.n_head_kv = hparams.n_head;
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
 
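// Aside (not part of the patch): a minimal sketch of how the tagged header field above
// round-trips, assuming the read/write logic shown in this diff; the helper names below are
// local to the sketch, not taken from the codebase. Older files stored n_rot in this slot
// (a small value such as 64, high nibble clear); newer files store n_vocab_base | 0xF0000000,
// so the loader can tell the two layouts apart and fall back to n_vocab for old files.
#include <cstdint>
#include <cstdio>

static uint32_t encode_vocab_base(uint32_t n_vocab_base) {
    return n_vocab_base | 0xF0000000;   // what write_hparams() emits in the old n_rot slot
}

static uint32_t decode_vocab_base(uint32_t stored, uint32_t n_vocab) {
    // high nibble clear -> old file (the slot held n_rot): fall back to n_vocab
    // high nibble set   -> new file: unmask to recover n_vocab_base
    return (stored & 0xF0000000) == 0 ? n_vocab : (stored & ~0xF0000000);
}

int main() {
    std::printf("%u\n", decode_vocab_base(64, 32000));                       // old file -> 32000
    std::printf("%u\n", decode_vocab_base(encode_vocab_base(32000), 32016)); // new file -> 32000
}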
@@ -574,20 +572,6 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
-
-        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
-
-        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            llama_vocab::id token_id = file.read_u32();
-            const auto & word = vocab.id_to_token[token_id].tok;
-
-            vocab.special_token_trie.add(word);
-            vocab.special_token_to_id[word] = token_id;
-
-            if (vocab.max_special_token_length < word.size()) {
-                vocab.max_special_token_length = word.size();
-            }
-        }
     }
     void read_tensor_metadata(llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -634,6 +618,24 @@ struct llama_file_loader {
             tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1;
         }
     }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (word.empty()) {
+                continue;
+            }
+
+            vocab.special_token_trie.add(word);
+            vocab.special_token_to_id[word] = token_id;
+
+            if (vocab.max_special_token_length < word.size()) {
+                vocab.max_special_token_length = word.size();
+            }
+        }
+    }
 };
 
 struct llama_file_saver {
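// Aside (not part of the patch): the counting behind set_vocab_sp() above, with hypothetical
// header values (n_vocab = 32016, n_vocab_base = 32000). The candidate set is the three
// control tokens (ids 0-2) plus the tokens appended after the base vocabulary; candidates
// whose text is empty are skipped, the rest go into special_token_trie / special_token_to_id.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_vocab      = 32016;  // hypothetical: base vocab plus 16 added tokens
    const uint32_t n_vocab_base = 32000;
    const uint32_t vocab_sp     = 3 + n_vocab - n_vocab_base;  // same formula as set_vocab_sp()
    std::printf("special-token candidates: %u\n", vocab_sp);   // prints 19
}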
@@ -653,12 +655,11 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
-        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
@@ -672,9 +673,6 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
-            file.write_u32(pair.second);
-        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -1001,8 +999,7 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
-        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
     }
 
     return "unknown";
@@ -1127,7 +1124,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_head     = %u\n", __func__, hparams.n_head);
         fprintf(stderr, "%s: n_head_kv  = %u\n", __func__, hparams.n_head_kv);
         fprintf(stderr, "%s: n_layer    = %u\n", __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+        fprintf(stderr, "%s: n_rot      = %u\n", __func__, hparams.n_embd/hparams.n_head); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa      = %u\n", __func__, hparams.n_gqa());
         fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff       = %u\n", __func__, n_ff);
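// Aside (not part of the patch): since the header slot that used to carry n_rot now carries
// n_vocab_base, the log line above derives n_rot as n_embd / n_head; for the default hparams
// shown earlier (n_embd = 4096, n_head = 32) this prints n_rot = 128.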