Skip to content

Commit e3ac833

Browse files
committed
using abort_callback from ggml to stop llama computation
1 parent 67fd331 commit e3ac833

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

ggml-backend.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ extern "C" {
8080
//
8181
// CPU backend
8282
//
83-
8483
GGML_API ggml_backend_t ggml_backend_cpu_init(void);
8584

8685
GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);

llama.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,6 +1948,9 @@ struct llama_context {
19481948
std::vector<uint8_t> buf_compute_meta;
19491949
ggml_backend_sched_t sched = nullptr;
19501950

1951+
ggml_abort_callback abort_callback = nullptr;
1952+
void * abort_callback_data = nullptr;
1953+
19511954
// input tensors
19521955
ggml_backend_buffer_t buf_input = nullptr;
19531956
ggml_context * ctx_input = nullptr;
@@ -7847,6 +7850,7 @@ static void llama_graph_compute(
78477850

78487851
if (lctx.backend_cpu != nullptr) {
78497852
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
7853+
ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
78507854
}
78517855

78527856
ggml_backend_sched_graph_compute(lctx.sched, gf);
@@ -11644,6 +11648,8 @@ struct llama_context_params llama_context_default_params() {
1164411648
/*.embedding =*/ false,
1164511649
/*.offload_kqv =*/ true,
1164611650
/*.do_pooling =*/ true,
11651+
/*.abort_callback =*/ nullptr,
11652+
/*.abort_callback_data =*/ nullptr,
1164711653
};
1164811654

1164911655
return result;
@@ -11835,8 +11841,11 @@ struct llama_context * llama_new_context_with_model(
1183511841
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
1183611842
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
1183711843

11838-
ctx->rng = std::mt19937(params.seed);
11839-
ctx->logits_all = params.logits_all;
11844+
ctx->abort_callback = params.abort_callback;
11845+
ctx->abort_callback_data = params.abort_callback_data;
11846+
11847+
ctx->rng = std::mt19937(params.seed);
11848+
ctx->logits_all = params.logits_all;
1184011849

1184111850
const ggml_type type_k = params.type_k;
1184211851
const ggml_type type_v = params.type_v;
@@ -12809,6 +12818,11 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
1280912818
ctx->cparams.n_threads_batch = n_threads_batch;
1281012819
}
1281112820

12821+
void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
12822+
ctx->abort_callback = abort_callback;
12823+
ctx->abort_callback_data = abort_callback_data;
12824+
}
12825+
1281212826
struct llama_batch llama_batch_get_one(
1281312827
llama_token * tokens,
1281412828
int32_t n_tokens,

llama.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ extern "C" {
256256
bool embedding; // embedding mode only
257257
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
258258
bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
259+
260+
ggml_abort_callback abort_callback;
261+
void * abort_callback_data;
259262
};
260263

261264
// model quantization parameters
@@ -661,6 +664,9 @@ extern "C" {
661664
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
662665
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
663666

667+
// Set abort callback
668+
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
669+
664670
// Token logits obtained from the last call to llama_eval()
665671
// The logits for the last token are stored in the last row
666672
// Logits for which llama_batch.logits[i] == 0 are undefined

0 commit comments

Comments (0)