@@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
64
64
}
65
65
}
66
66
67
+ // return a list of splits for a given path
68
+ // for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
69
+ static std::vector<std::string> llama_get_list_splits (const std::string & path, const int idx, const int n_split) {
70
+ std::vector<std::string> paths;
71
+ std::string split_prefix;
72
+ std::vector<char > buf (llama_path_max (), 0 );
73
+
74
+ {
75
+ int ret = llama_split_prefix (buf.data (), buf.size (), path.c_str (), idx, n_split);
76
+ if (!ret) {
77
+ throw std::runtime_error (format (" invalid split file name: %s" , path.c_str ()));
78
+ }
79
+ split_prefix = std::string (buf.data (), ret);
80
+ }
81
+
82
+ if (split_prefix.empty ()) {
83
+ throw std::runtime_error (format (" invalid split file: %s" , path.c_str ()));
84
+ }
85
+
86
+ for (int idx = 0 ; idx < n_split; ++idx) {
87
+ int ret = llama_split_path (buf.data (), buf.size (), split_prefix.c_str (), idx, n_split);
88
+ paths.push_back (std::string (buf.data (), ret));
89
+ }
90
+
91
+ return paths;
92
+ }
93
+
67
94
namespace GGUFMeta {
68
95
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t )>
69
96
struct GKV_Base_Type {
@@ -413,7 +440,12 @@ namespace GGUFMeta {
413
440
template bool llama_model_loader::get_key_or_arr<std::array<int , 4 >>(enum llm_kv kid, std::array<int , 4 > & result, uint32_t n, bool required);
414
441
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t , 512 >>(enum llm_kv kid, std::array<uint32_t , 512 > & result, uint32_t n, bool required);
415
442
416
- llama_model_loader::llama_model_loader (const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
443
+ llama_model_loader::llama_model_loader (
444
+ const std::string & fname,
445
+ std::vector<std::string> & splits,
446
+ bool use_mmap,
447
+ bool check_tensors,
448
+ const struct llama_model_kv_override * param_overrides_p) {
417
449
int trace = 0 ;
418
450
if (getenv (" LLAMA_TRACE" )) {
419
451
trace = atoi (getenv (" LLAMA_TRACE" ));
@@ -425,6 +457,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
425
457
}
426
458
}
427
459
460
+ // Load the main GGUF
428
461
struct ggml_context * ctx = NULL ;
429
462
struct gguf_init_params params = {
430
463
/* .no_alloc = */ true ,
@@ -460,35 +493,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
460
493
461
494
// Load additional GGML contexts
462
495
if (n_split > 1 ) {
496
+ // make sure the main file is loaded first
463
497
uint16_t idx = 0 ;
464
- get_key (llm_kv (LLM_KV_SPLIT_NO), idx);
498
+ const std::string kv_split_no = llm_kv (LLM_KV_SPLIT_NO);
499
+ get_key (kv_split_no, idx);
465
500
if (idx != 0 ) {
466
- throw std::runtime_error (format (" illegal split file: %d, model must be loaded with the first split" , idx));
501
+ throw std::runtime_error (format (" illegal split file idx: %d (file: %s), model must be loaded with the first split" , idx, fname.c_str ()));
502
+ }
503
+
504
+ // generate list of splits if needed
505
+ if (splits.empty ()) {
506
+ splits = llama_get_list_splits (fname, idx, n_split);
467
507
}
468
508
469
- std::vector< char > split_prefix ( llama_path_max (), 0 );
470
- if (! llama_split_prefix (split_prefix. data (), split_prefix .size (), fname. c_str (), idx, n_split )) {
471
- throw std::runtime_error (format (" invalid split file : %s " , fname. c_str () ));
509
+ // in case user give a custom list of splits, check if it matches the expected number
510
+ if (n_split != ( uint16_t )splits .size ()) {
511
+ throw std::runtime_error (format (" invalid split count, given : %zu splits, but expected %d " , splits. size (), n_split ));
472
512
}
473
513
474
514
if (trace > 0 ) {
475
515
LLAMA_LOG_INFO (" %s: loading additional %d GGUFs\n " , __func__, n_split);
476
516
}
477
517
478
- std::vector< char > split_path ( llama_path_max (), 0 );
518
+ // load other splits
479
519
for (idx = 1 ; idx < n_split; idx++) {
480
- llama_split_path (split_path. data (), split_path. size (), split_prefix. data (), idx, n_split );
520
+ const char * fname_split = splits[idx]. c_str ( );
481
521
482
522
struct gguf_init_params split_params = {
483
523
/* .no_alloc = */ true ,
484
524
/* .ctx = */ &ctx,
485
525
};
486
- gguf_context_ptr ctx_gguf { gguf_init_from_file (split_path. data () , split_params) };
526
+ gguf_context_ptr ctx_gguf { gguf_init_from_file (fname_split , split_params) };
487
527
if (!ctx_gguf) {
488
- throw std::runtime_error (format (" %s: failed to load GGUF split from %s\n " , __func__, split_path.data ()));
528
+ throw std::runtime_error (format (" %s: failed to load GGUF split from %s\n " , __func__, fname_split));
529
+ }
530
+
531
+ // check idx
532
+ {
533
+ const int kid = gguf_find_key (ctx_gguf.get (), kv_split_no.c_str ());
534
+ if (kid < 0 ) {
535
+ throw std::runtime_error (format (" missing key %s in GGUF split %s" , kv_split_no.c_str (), fname_split));
536
+ }
537
+ int idx_gguf = gguf_get_val_u16 (ctx_gguf.get (), kid);
538
+ if (idx_gguf != idx) {
539
+ throw std::runtime_error (format (" invalid split file idx: %d (file: %s), expected %d" , idx_gguf, fname_split, idx));
540
+ }
489
541
}
490
542
491
- files.emplace_back (new llama_file (split_path. data () , " rb" ));
543
+ files.emplace_back (new llama_file (fname_split , " rb" ));
492
544
contexts.emplace_back (ctx);
493
545
494
546
// Save tensors data offset info of the shard.
0 commit comments