@@ -2297,7 +2297,8 @@ struct llama_model_loader {
         }
     }
 
-    void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
+    // Returns false if cancelled by progress_callback
+    bool load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
         size_t size_data = 0;
         size_t size_lock = 0;
         size_t size_pref = 0; // prefetch
@@ -2323,7 +2324,9 @@ struct llama_model_loader {
             GGML_ASSERT(cur); // unused tensors should have been caught by load_data already
 
             if (progress_callback) {
-                progress_callback((float) done_size / size_data, progress_callback_user_data);
+                if (!progress_callback((float) done_size / size_data, progress_callback_user_data)) {
+                    return false;
+                }
             }
 
             // allocate temp buffer if not using mmap
@@ -2371,6 +2374,7 @@ struct llama_model_loader {
 
             done_size += ggml_nbytes(cur);
         }
+        return true;
     }
 };
 
@@ -2937,7 +2941,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token  = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
 }
 
-static void llm_load_tensors(
+// Returns false if cancelled by progress_callback
+static bool llm_load_tensors(
         llama_model_loader & ml,
         llama_model & model,
         int n_gpu_layers,
@@ -2948,6 +2953,8 @@ static void llm_load_tensors(
         void * progress_callback_user_data) {
     model.t_start_us = ggml_time_us();
 
+    bool ok = true; // if false, model load was cancelled
+
     auto & ctx     = model.ctx;
     auto & hparams = model.hparams;
 
@@ -3678,20 +3685,23 @@ static void llm_load_tensors(
     }
 #endif
 
-    ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);
-
+    ok = ok && ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL);
     if (progress_callback) {
-        progress_callback(1.0f, progress_callback_user_data);
+        // Even though the model is done loading, we still honor
+        // cancellation since we need to free allocations.
+        ok = ok && progress_callback(1.0f, progress_callback_user_data);
     }
 
     model.mapping = std::move(ml.mapping);
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = ggml_time_us() - model.t_start_us;
+    return ok;
 }
 
-static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
+// Returns -1 on error, -2 on cancellation via llama_progress_callback
+static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
     try {
         llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
 
@@ -3712,16 +3722,18 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con
             return true;
         }
 
-        llm_load_tensors(
+        if (!llm_load_tensors(
             ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
             params.progress_callback, params.progress_callback_user_data
-        );
+        )) {
+            return -2;
+        }
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
-        return false;
+        return -1;
     }
 
-    return true;
+    return 0;
 }
 
 //
@@ -9017,11 +9029,18 @@ struct llama_model * llama_load_model_from_file(
                     LLAMA_LOG_INFO("\n");
                 }
             }
+            return true;
         };
     }
 
-    if (!llama_model_load(path_model, *model, params)) {
-        LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
+    int status = llama_model_load(path_model, *model, params);
+    GGML_ASSERT(status <= 0);
+    if (status < 0) {
+        if (status == -1) {
+            LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
+        } else if (status == -2) {
+            LLAMA_LOG_INFO("%s, cancelled model load\n", __func__);
+        }
         delete model;
         return nullptr;
     }
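
For reference, a minimal caller-side sketch of how the new cancellable callback could be used. This is not part of the diff: the names g_cancel_load, my_progress, and try_load are hypothetical, and only the public llama.h entry points (llama_model_default_params, llama_load_model_from_file) and the progress_callback / progress_callback_user_data fields of llama_model_params are assumed from the existing API. Returning false from the callback is what makes llama_model_load() return -2, which llama_load_model_from_file() turns into a nullptr result.

#include <atomic>
#include "llama.h"

// Flag flipped from another thread (or a signal handler) to abort the load.
static std::atomic<bool> g_cancel_load{false};

// After this change, llama_progress_callback returns bool: false cancels the load.
static bool my_progress(float progress, void * user_data) {
    (void) progress;
    (void) user_data;
    return !g_cancel_load.load();
}

static struct llama_model * try_load(const char * path) {
    struct llama_model_params mparams = llama_model_default_params();
    mparams.progress_callback           = my_progress;
    mparams.progress_callback_user_data = nullptr;
    // A nullptr result covers both a failed load (-1) and a cancelled load (-2).
    return llama_load_model_from_file(path, mparams);
}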