@@ -413,7 +413,12 @@ namespace GGUFMeta {
413
413
template bool llama_model_loader::get_key_or_arr<std::array<int , 4 >>(enum llm_kv kid, std::array<int , 4 > & result, uint32_t n, bool required);
414
414
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t , 512 >>(enum llm_kv kid, std::array<uint32_t , 512 > & result, uint32_t n, bool required);
415
415
416
- llama_model_loader::llama_model_loader (const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
416
+ llama_model_loader::llama_model_loader (
417
+ const std::string & fname,
418
+ std::vector<std::string> & splits,
419
+ bool use_mmap,
420
+ bool check_tensors,
421
+ const struct llama_model_kv_override * param_overrides_p) {
417
422
int trace = 0 ;
418
423
if (getenv (" LLAMA_TRACE" )) {
419
424
trace = atoi (getenv (" LLAMA_TRACE" ));
@@ -425,6 +430,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
425
430
}
426
431
}
427
432
433
+ // Load the main GGUF
428
434
struct ggml_context * ctx = NULL ;
429
435
struct gguf_init_params params = {
430
436
/* .no_alloc = */ true ,
@@ -460,35 +466,52 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
460
466
461
467
// Load additional GGML contexts
462
468
if (n_split > 1 ) {
463
- uint16_t idx = 0 ;
464
- get_key (llm_kv (LLM_KV_SPLIT_NO), idx);
465
- if (idx != 0 ) {
466
- throw std::runtime_error (format (" illegal split file: %d, model must be loaded with the first split" , idx));
469
+ // generate list of splits if needed
470
+ if (splits.empty ()) {
471
+ splits = llama_get_list_splits (fname, n_split);
467
472
}
468
473
469
- std::vector<char > split_prefix (llama_path_max (), 0 );
470
- if (!llama_split_prefix (split_prefix.data (), split_prefix.size (), fname.c_str (), idx, n_split)) {
471
- throw std::runtime_error (format (" invalid split file: %s" , fname.c_str ()));
474
+ // in case the user gives a custom list of splits, check that it matches the expected number
475
+ if (n_split != (uint16_t )splits.size ()) {
476
+ throw std::runtime_error (format (" invalid split count, given: %zu splits, but expected %d" , splits.size (), n_split));
477
+ }
478
+
479
+ uint16_t idx = 0 ;
480
+ const std::string kv_split_no = llm_kv (LLM_KV_SPLIT_NO);
481
+ get_key (kv_split_no, idx);
482
+ if (idx != 0 ) {
483
+ throw std::runtime_error (format (" illegal split file idx: %d (file: %s), model must be loaded with the first split" , idx, fname.c_str ()));
472
484
}
473
485
474
486
if (trace > 0 ) {
475
487
LLAMA_LOG_INFO (" %s: loading additional %d GGUFs\n " , __func__, n_split);
476
488
}
477
489
478
- std::vector<char > split_path (llama_path_max (), 0 );
479
490
for (idx = 1 ; idx < n_split; idx++) {
480
- llama_split_path (split_path. data (), split_path. size (), split_prefix. data (), idx, n_split );
491
+ const char * fname_split = splits[idx]. c_str ( );
481
492
482
493
struct gguf_init_params split_params = {
483
494
/* .no_alloc = */ true ,
484
495
/* .ctx = */ &ctx,
485
496
};
486
- gguf_context_ptr ctx_gguf { gguf_init_from_file (split_path. data () , split_params) };
497
+ gguf_context_ptr ctx_gguf { gguf_init_from_file (fname_split , split_params) };
487
498
if (!ctx_gguf) {
488
- throw std::runtime_error (format (" %s: failed to load GGUF split from %s\n " , __func__, split_path.data ()));
499
+ throw std::runtime_error (format (" %s: failed to load GGUF split from %s\n " , __func__, fname_split));
500
+ }
501
+
502
+ // check idx
503
+ {
504
+ const int kid = gguf_find_key (ctx_gguf.get (), kv_split_no.c_str ());
505
+ if (kid < 0 ) {
506
+ throw std::runtime_error (format (" missing key %s in GGUF split %s" , kv_split_no.c_str (), fname_split));
507
+ }
508
+ int idx_gguf = gguf_get_val_u16 (ctx_gguf.get (), kid);
509
+ if (idx_gguf != idx) {
510
+ throw std::runtime_error (format (" invalid split file idx: %d (file: %s), expected %d" , idx_gguf, fname_split, idx));
511
+ }
489
512
}
490
513
491
- files.emplace_back (new llama_file (split_path. data () , " rb" ));
514
+ files.emplace_back (new llama_file (fname_split , " rb" ));
492
515
contexts.emplace_back (ctx);
493
516
494
517
// Save tensors data offset info of the shard.
@@ -1070,3 +1093,28 @@ void llama_model_loader::print_info() const {
1070
1093
LLAMA_LOG_INFO (" %s: file size = %.2f GiB (%.2f BPW) \n " , __func__, n_bytes/1024.0 /1024.0 /1024.0 , n_bytes*8.0 /n_elements);
1071
1094
}
1072
1095
}
1096
+
1097
+ std::vector<std::string> llama_get_list_splits (const std::string & path, const int n_split) {
1098
+ std::vector<std::string> paths;
1099
+ std::string split_prefix;
1100
+ std::vector<char > buf (llama_path_max (), 0 );
1101
+
1102
+ // brute force to find the split prefix
1103
+ for (int idx = 0 ; idx < n_split; ++idx) {
1104
+ int ret = llama_split_prefix (buf.data (), buf.size (), path.c_str (), idx, n_split);
1105
+ if (ret) {
1106
+ split_prefix = std::string (buf.data (), ret);
1107
+ }
1108
+ }
1109
+
1110
+ if (split_prefix.empty ()) {
1111
+ throw std::runtime_error (format (" invalid split file: %s" , path.c_str ()));
1112
+ }
1113
+
1114
+ for (int idx = 0 ; idx < n_split; ++idx) {
1115
+ int ret = llama_split_path (buf.data (), buf.size (), split_prefix.c_str (), idx, n_split);
1116
+ paths.push_back (std::string (buf.data (), ret));
1117
+ }
1118
+
1119
+ return paths;
1120
+ }
0 commit comments