Commit 9abe2e4

llama : Add ability to cancel model load
Updated llama_progress_callback so that if it returns false, the model loading is aborted.
1 parent 55e87c3 commit 9abe2e4
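
For illustration, a minimal caller-side sketch of the new contract (not part of this commit): a flag set from another thread aborts an in-flight load. It assumes the public API of this revision (llama_model_default_params, llama_load_model_from_file); cancellable_progress, try_load_model, and g_cancel_load are hypothetical names.

#include <atomic>
#include "llama.h"

static std::atomic<bool> g_cancel_load{false}; // set to true elsewhere to abort

// Under the new typedef the callback returns bool:
// true = keep loading, false = abort the load.
static bool cancellable_progress(float progress, void * ctx) {
    (void) progress;
    (void) ctx;
    return !g_cancel_load.load();
}

static struct llama_model * try_load_model(const char * path) {
    struct llama_model_params mparams = llama_model_default_params();
    mparams.progress_callback           = cancellable_progress;
    mparams.progress_callback_user_data = NULL;
    // Returns NULL when the callback cancelled the load (or on any other failure).
    return llama_load_model_from_file(path, mparams);
}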

File tree (2 files changed: +36, -15 lines)

  llama.cpp
  llama.h

llama.cpp

Lines changed: 32 additions & 13 deletions
@@ -2297,7 +2297,8 @@ struct llama_model_loader {
         }
     }

-    void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
+    // Returns false if cancelled by progress_callback
+    bool load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
         size_t size_data = 0;
         size_t size_lock = 0;
         size_t size_pref = 0; // prefetch
@@ -2323,7 +2324,9 @@ struct llama_model_loader {
             GGML_ASSERT(cur); // unused tensors should have been caught by load_data already

             if (progress_callback) {
-                progress_callback((float) done_size / size_data, progress_callback_user_data);
+                if (!progress_callback((float) done_size / size_data, progress_callback_user_data)) {
+                    return false;
+                }
             }

             // allocate temp buffer if not using mmap
@@ -2371,6 +2374,7 @@ struct llama_model_loader {

             done_size += ggml_nbytes(cur);
         }
+        return true;
     }
 };

@@ -2937,7 +2941,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
 }

-static void llm_load_tensors(
+// Returns false if cancelled by progress_callback
+static bool llm_load_tensors(
         llama_model_loader & ml,
         llama_model & model,
         int n_gpu_layers,
@@ -2948,6 +2953,8 @@ static void llm_load_tensors(
         void * progress_callback_user_data) {
     model.t_start_us = ggml_time_us();

+    bool ok = true; // if false, model load was cancelled
+
     auto & ctx = model.ctx;
     auto & hparams = model.hparams;

@@ -3678,20 +3685,23 @@ static void llm_load_tensors(
     }
 #endif

-    ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);
-
+    ok = ok && ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);
     if (progress_callback) {
-        progress_callback(1.0f, progress_callback_user_data);
+        // Even though the model is done loading, we still honor
+        // cancellation since we need to free allocations.
+        ok = ok && progress_callback(1.0f, progress_callback_user_data);
     }

     model.mapping = std::move(ml.mapping);

     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = ggml_time_us() - model.t_start_us;
+    return ok;
 }

-static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
+// Returns -1 on error, -2 on cancellation via llama_progress_callback
+static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
     try {
         llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);

@@ -3712,16 +3722,18 @@ static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
             return true;
         }

-        llm_load_tensors(
+        if (!llm_load_tensors(
             ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
             params.progress_callback, params.progress_callback_user_data
-        );
+        )) {
+            return -2;
+        }
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
-        return false;
+        return -1;
     }

-    return true;
+    return 0;
 }

 //
@@ -9017,11 +9029,18 @@ struct llama_model * llama_load_model_from_file(
                     LLAMA_LOG_INFO("\n");
                 }
             }
+            return true;
         };
     }

-    if (!llama_model_load(path_model, *model, params)) {
-        LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
+    int status = llama_model_load(path_model, *model, params);
+    GGML_ASSERT(status <= 0);
+    if (status < 0) {
+        if (status == -1) {
+            LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
+        } else if (status == -2) {
+            LLAMA_LOG_INFO("%s, cancelled model load\n", __func__);
+        }
         delete model;
         return nullptr;
     }

llama.h

Lines changed: 4 additions & 2 deletions
@@ -126,7 +126,7 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;

-    typedef void (*llama_progress_callback)(float progress, void *ctx);
+    typedef bool (*llama_progress_callback)(float progress, void *ctx);

     // Input data for llama_decode
     // A llama_batch object can contain input about one or many sequences
@@ -179,7 +179,9 @@ extern "C" {
         int32_t main_gpu;           // the GPU that is used for scratch and small tensors
         const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)

-        // called with a progress value between 0 and 1, pass NULL to disable
+        // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
+        // If the provided progress_callback returns true, model loading continues.
+        // If it returns false, model loading is immediately aborted.
         llama_progress_callback progress_callback;

        // context pointer passed to the progress callback
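
A second hypothetical sketch (again, not part of this commit) showing cancellation state passed through progress_callback_user_data rather than a global, here a wall-clock deadline; deadline_progress and load_with_timeout are invented names.

#include <time.h>
#include "llama.h"

// Cancel the load if it runs past a deadline carried in via the user-data pointer.
static bool deadline_progress(float progress, void * ctx) {
    const time_t deadline = *(const time_t *) ctx;
    if (progress < 1.0f && time(NULL) > deadline) {
        return false; // abort: llama_load_model_from_file will return NULL
    }
    return true; // keep loading
}

static struct llama_model * load_with_timeout(const char * path, int seconds) {
    time_t deadline = time(NULL) + seconds; // callback runs synchronously, so a local is safe
    struct llama_model_params mparams = llama_model_default_params();
    mparams.progress_callback           = deadline_progress;
    mparams.progress_callback_user_data = &deadline;
    return llama_load_model_from_file(path, mparams);
}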
