Skip to content

Commit 1782462

Browse files
committed
llama : add llama_model_load_from_splits
1 parent 1d85043 commit 1782462

File tree

5 files changed

+117
-26
lines changed

5 files changed

+117
-26
lines changed

include/llama.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,20 @@ extern "C" {
418418
struct llama_model_params params),
419419
"use llama_model_load_from_file instead");
420420

421+
// Load the model from a file
422+
// If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
423+
// If the split file name does not follow this pattern, use llama_model_load_from_splits
421424
LLAMA_API struct llama_model * llama_model_load_from_file(
422425
const char * path_model,
423426
struct llama_model_params params);
424427

428+
// Load the model from multiple splits (support custom naming scheme)
429+
// The paths must be in the correct order
430+
LLAMA_API struct llama_model * llama_model_load_from_splits(
431+
const char ** paths,
432+
size_t n_paths,
433+
struct llama_model_params params);
434+
425435
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
426436
"use llama_model_free instead");
427437

src/llama-model-loader.cpp

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,12 @@ namespace GGUFMeta {
413413
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
414414
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
415415

416-
llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
416+
llama_model_loader::llama_model_loader(
417+
const std::string & fname,
418+
std::vector<std::string> & splits,
419+
bool use_mmap,
420+
bool check_tensors,
421+
const struct llama_model_kv_override * param_overrides_p) {
417422
int trace = 0;
418423
if (getenv("LLAMA_TRACE")) {
419424
trace = atoi(getenv("LLAMA_TRACE"));
@@ -425,6 +430,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
425430
}
426431
}
427432

433+
// Load the main GGUF
428434
struct ggml_context * ctx = NULL;
429435
struct gguf_init_params params = {
430436
/*.no_alloc = */ true,
@@ -460,35 +466,52 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
460466

461467
// Load additional GGML contexts
462468
if (n_split > 1) {
463-
uint16_t idx = 0;
464-
get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
465-
if (idx != 0) {
466-
throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
469+
// generate list of splits if needed
470+
if (splits.empty()) {
471+
splits = llama_get_list_splits(fname, n_split);
467472
}
468473

469-
std::vector<char> split_prefix(llama_path_max(), 0);
470-
if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
471-
throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
474+
// in case user give a custom list of splits, check if it matches the expected number
475+
if (n_split != (uint16_t)splits.size()) {
476+
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
477+
}
478+
479+
uint16_t idx = 0;
480+
const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
481+
get_key(kv_split_no, idx);
482+
if (idx != 0) {
483+
throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
472484
}
473485

474486
if (trace > 0) {
475487
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
476488
}
477489

478-
std::vector<char> split_path(llama_path_max(), 0);
479490
for (idx = 1; idx < n_split; idx++) {
480-
llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
491+
const char * fname_split = splits[idx].c_str();
481492

482493
struct gguf_init_params split_params = {
483494
/*.no_alloc = */ true,
484495
/*.ctx = */ &ctx,
485496
};
486-
gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
497+
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
487498
if (!ctx_gguf) {
488-
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
499+
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
500+
}
501+
502+
// check idx
503+
{
504+
const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
505+
if (kid < 0) {
506+
throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
507+
}
508+
int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
509+
if (idx_gguf != idx) {
510+
throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
511+
}
489512
}
490513

491-
files.emplace_back(new llama_file(split_path.data(), "rb"));
514+
files.emplace_back(new llama_file(fname_split, "rb"));
492515
contexts.emplace_back(ctx);
493516

494517
// Save tensors data offset info of the shard.
@@ -1070,3 +1093,28 @@ void llama_model_loader::print_info() const {
10701093
LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements);
10711094
}
10721095
}
1096+
1097+
std::vector<std::string> llama_get_list_splits(const std::string & path, const int n_split) {
1098+
std::vector<std::string> paths;
1099+
std::string split_prefix;
1100+
std::vector<char> buf(llama_path_max(), 0);
1101+
1102+
// brute force to find the split prefix
1103+
for (int idx = 0; idx < n_split; ++idx) {
1104+
int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
1105+
if (ret) {
1106+
split_prefix = std::string(buf.data(), ret);
1107+
}
1108+
}
1109+
1110+
if (split_prefix.empty()) {
1111+
throw std::runtime_error(format("invalid split file: %s", path.c_str()));
1112+
}
1113+
1114+
for (int idx = 0; idx < n_split; ++idx) {
1115+
int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
1116+
paths.push_back(std::string(buf.data(), ret));
1117+
}
1118+
1119+
return paths;
1120+
}

src/llama-model-loader.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ struct llama_model_loader {
9090
size_t size_data = 0;
9191
std::vector<std::pair<size_t, size_t>> mmaps_used;
9292

93-
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p);
93+
llama_model_loader(
94+
const std::string & fname,
95+
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
96+
bool use_mmap,
97+
bool check_tensors,
98+
const struct llama_model_kv_override * param_overrides_p);
9499

95100
template<typename T>
96101
typename std::enable_if<std::is_integral<T>::value, bool>::type
@@ -160,3 +165,7 @@ struct llama_model_loader {
160165

161166
void print_info() const;
162167
};
168+
169+
// return a list of splits for a given path
170+
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
171+
std::vector<std::string> llama_get_list_splits(const std::string & path, const int n_split);

src/llama-quant.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
526526
kv_overrides = v->data();
527527
}
528528

529-
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
529+
std::vector<std::string> splits = {};
530+
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
530531
ml.init_mappings(false); // no prefetching
531532

532533
llama_model model(llama_model_default_params());

src/llama.cpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#endif
3232

3333
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
34-
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
34+
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
3535
// loading time will be recalculated after the first eval, so
3636
// we take page faults deferred by mmap() into consideration
3737
model.t_load_us = 0;
@@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
4040
model.t_start_us = tm.t_start_us;
4141

4242
try {
43-
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
43+
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
4444

4545
ml.print_info();
4646

@@ -9374,14 +9374,9 @@ int64_t llama_time_us(void) {
93749374
return ggml_time_us();
93759375
}
93769376

9377-
struct llama_model * llama_load_model_from_file(
9378-
const char * path_model,
9379-
struct llama_model_params params) {
9380-
return llama_model_load_from_file(path_model, params);
9381-
}
9382-
9383-
struct llama_model * llama_model_load_from_file(
9384-
const char * path_model,
9377+
static struct llama_model * llama_model_load_from_file_impl(
9378+
const std::string & path_model,
9379+
std::vector<std::string> & splits,
93859380
struct llama_model_params params) {
93869381
ggml_time_init();
93879382

@@ -9485,7 +9480,7 @@ struct llama_model * llama_model_load_from_file(
94859480
LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
94869481
}
94879482

9488-
const int status = llama_model_load(path_model, *model, params);
9483+
const int status = llama_model_load(path_model, splits, *model, params);
94899484
GGML_ASSERT(status <= 0);
94909485
if (status < 0) {
94919486
if (status == -1) {
@@ -9501,6 +9496,34 @@ struct llama_model * llama_model_load_from_file(
95019496
return model;
95029497
}
95039498

9499+
struct llama_model * llama_load_model_from_file(
9500+
const char * path_model,
9501+
struct llama_model_params params) {
9502+
return llama_model_load_from_file(path_model, params);
9503+
}
9504+
9505+
struct llama_model * llama_model_load_from_file(
9506+
const char * path_model,
9507+
struct llama_model_params params) {
9508+
std::vector<std::string> splits = {};
9509+
return llama_model_load_from_file_impl(path_model, splits, params);
9510+
}
9511+
9512+
struct llama_model * llama_model_load_from_splits(
9513+
const char ** paths,
9514+
size_t n_paths,
9515+
struct llama_model_params params) {
9516+
std::vector<std::string> splits;
9517+
if (n_paths == 0) {
9518+
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
9519+
return nullptr;
9520+
}
9521+
for (size_t i = 0; i < n_paths; ++i) {
9522+
splits.push_back(paths[i]);
9523+
}
9524+
return llama_model_load_from_file_impl(splits.front(), splits, params);
9525+
}
9526+
95049527
struct llama_context * llama_init_from_model(
95059528
struct llama_model * model,
95069529
struct llama_context_params params) {

0 commit comments

Comments
 (0)