Commit 6562e5a

context : allow cache-less context for embeddings (#13108)
* context : allow cache-less context for embeddings ggml-ci
* context : enable reranking with encode() ggml-ci
* context : encode() clears embd_seq ggml-ci
* examples : use llama_encode() when appropriate ggml-ci
* models : nomic bert moe does not require KV cache
* llama : update comments for llama_decode/llama_encode ggml-ci
* context : update warning log [no ci]
1 parent 51fb96b commit 6562e5a
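
For orientation, here is a minimal sketch (not part of the commit) of what the cache-less embedding path looks like from the public API after this change. The model path and prompt are placeholders, and error handling is trimmed:

```cpp
// Sketch only: embed a prompt via llama_encode() on an embeddings context.
#include "llama.h"

#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    llama_backend_init();

    llama_model * model = llama_model_load_from_file("embedding-model.gguf", llama_model_default_params()); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // embeddings only, no generation
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // one pooled vector per sequence

    llama_context * ctx = llama_init_from_model(model, cparams);

    // tokenize the prompt
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const char * prompt = "hello world";
    std::vector<llama_token> tokens(512);
    const int32_t n_tokens = llama_tokenize(vocab, prompt, (int32_t) strlen(prompt),
                                            tokens.data(), (int32_t) tokens.size(),
                                            /*add_special=*/true, /*parse_special=*/false);
    if (n_tokens < 0) {
        return 1;
    }
    tokens.resize(n_tokens);

    // cache-less contexts are driven with llama_encode();
    // calling llama_decode() would now fall back to encode() with a warning
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    if (llama_encode(ctx, batch) < 0) {
        fprintf(stderr, "encode failed\n");
        return 1;
    }

    // pooled embedding for sequence 0
    const float * emb = llama_get_embeddings_seq(ctx, 0);
    printf("n_embd = %d, emb[0] = %f\n", llama_model_n_embd(model), emb ? emb[0] : 0.0f);

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```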

File tree

5 files changed: +47 -23 lines changed

examples/embedding/embedding.cpp

Lines changed: 2 additions & 11 deletions
@@ -35,23 +35,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke

 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    const struct llama_model * model = llama_get_model(ctx);

     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);

     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
-        // encoder-only model
-        if (llama_encode(ctx, batch) < 0) {
-            LOG_ERR("%s : failed to encode\n", __func__);
-        }
-    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
-        // decoder-only model
-        if (llama_decode(ctx, batch) < 0) {
-            LOG_ERR("%s : failed to decode\n", __func__);
-        }
+    if (llama_encode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to encode\n", __func__);
     }

     for (int i = 0; i < batch.n_tokens; i++) {

include/llama.h

Lines changed: 7 additions & 2 deletions
@@ -922,14 +922,19 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);

-    // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
-    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    // Process a batch of tokens.
+    // In contrast to llama_decode() - this call does not use KV cache.
+    // For encode-decoder contexts, processes the batch using the encoder.
+    // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     // 0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
             struct llama_batch batch);

+    // Process a batch of tokens.
+    // Requires KV cache.
+    // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
     // 0 - success
     // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
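
The updated comments draw the dividing line: llama_encode() never uses the KV cache, while llama_decode() requires one. A hypothetical caller-side helper (not part of llama.h) that follows this split, mirroring the server change further down:

```cpp
#include "llama.h"

// Hypothetical helper, not from the commit: route a batch to the matching call.
static int32_t process_batch(llama_context * ctx, llama_batch batch, bool embeddings_or_rerank) {
    if (embeddings_or_rerank) {
        return llama_encode(ctx, batch); // cache-less path: no KV cache involved
    }
    return llama_decode(ctx, batch);     // generation path: requires the KV cache
}
```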

src/llama-context.cpp

Lines changed: 22 additions & 8 deletions
@@ -251,7 +251,7 @@ llama_context::llama_context(
     }

     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

@@ -700,6 +700,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }

+    embd_seq.clear();
+
     n_queued_tokens += n_tokens;

     const int64_t n_embd = hparams.n_embd;
@@ -761,12 +763,12 @@ int llama_context::encode(llama_batch & inp_batch) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);

-        GGML_ASSERT(embd != nullptr);
-
         switch (cparams.pooling_type) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
+                    GGML_ASSERT(embd != nullptr);
+
                     GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
@@ -791,11 +793,18 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                    // wait for an encoder model that requires this pooling type in order to test it
-                    // https://github.com/ggerganov/llama.cpp/pull/9510
-                    GGML_ABORT("RANK pooling not implemented yet");
-                }
+                    // extract the rerank score - a single float per sequence
+                    auto & embd_seq_out = embd_seq;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(1);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                    }
+                } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
@@ -833,6 +842,11 @@ int llama_context::encode(llama_batch & inp_batch) {
 }

 int llama_context::decode(llama_batch & inp_batch) {
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        return encode(inp_batch);
+    }
+
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
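
With LLAMA_POOLING_TYPE_RANK now handled in encode(), a rerank score can be read back through the usual per-sequence embeddings accessor. A sketch under these assumptions (not from the commit): ctx was created from a reranker model with pooling_type set to LLAMA_POOLING_TYPE_RANK, and batch holds one tokenized query/document pair assigned to seq_id:

```cpp
#include "llama.h"

#include <cstdio>

// Sketch only: read the single-float rerank score that encode() stores per sequence.
static bool print_rerank_score(llama_context * ctx, llama_batch batch, llama_seq_id seq_id) {
    if (llama_encode(ctx, batch) < 0) {
        fprintf(stderr, "encode failed\n");
        return false;
    }

    // RANK pooling produces one float per sequence, exposed like a pooled embedding
    const float * score = llama_get_embeddings_seq(ctx, seq_id);
    if (score == nullptr) {
        return false;
    }

    printf("rerank score for seq %d: %f\n", (int) seq_id, score[0]);
    return true;
}
```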

src/llama-model.cpp

Lines changed: 7 additions & 0 deletions
@@ -12852,6 +12852,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;

     switch (arch) {
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_NOMIC_BERT_MOE:
+            {
+                res = nullptr;
+            } break;
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
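
For the architectures added above, create_memory() now returns nullptr, so the context is created without a KV cache. Combined with the decode() fallback shown in llama-context.cpp, existing call sites keep working; a sketch (not from the commit, assuming ctx and batch are set up as in the first sketch above):

```cpp
#include "llama.h"

#include <cstdio>

// Sketch only: on a cache-less context, llama_decode() warns and forwards to llama_encode(),
// so legacy embedding code that still calls decode keeps producing embeddings.
static const float * embed_via_legacy_decode(llama_context * ctx, llama_batch batch) {
    if (llama_decode(ctx, batch) < 0) {
        fprintf(stderr, "failed to process batch\n");
        return nullptr;
    }
    return llama_get_embeddings_seq(ctx, 0); // pooled embedding for sequence 0
}
```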

tools/server/server.cpp

Lines changed: 9 additions & 2 deletions
@@ -3214,7 +3214,14 @@ struct server_context {
                 batch.logits + i,
             };

-            const int ret = llama_decode(ctx, batch_view);
+            int ret = 0;
+
+            if (params_base.embedding || params_base.reranking) {
+                ret = llama_encode(ctx, batch_view);
+            } else {
+                ret = llama_decode(ctx, batch_view);
+            }
+
             metrics.on_decoded(slots);

             if (ret != 0) {
@@ -3943,7 +3950,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
-            std::function<bool()> is_connection_closed,
+            const std::function<bool()> & is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
