Commit d3e64b9

llama : rework embeddings logic (#14208)
* llama : rework embeddings logic ggml-ci
* cont : fix rerank ggml-ci
* cont : engrish [no ci]
* cont : fix rerank ggml-ci
* server : support both embeddings and completions with single model ggml-ci
* cont : avoid embeddings_org ggml-ci
1 parent 3ba0d84 commit d3e64b9

16 files changed (+159, -114 lines)

common/arg.cpp

Lines changed: 3 additions & 6 deletions
@@ -988,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.tensor_buft_overrides.push_back({nullptr, nullptr});
     }
 
-    if (params.reranking && params.embedding) {
-        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
-    }
-
     if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
         throw std::runtime_error(string_format(
             "error: the supplied chat template is not supported: %s%s\n",
@@ -2747,9 +2743,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--reranking", "--rerank"},
-        string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        string_format("enable reranking endpoint on server (default: %s)", "disabled"),
         [](common_params & params) {
-            params.reranking = true;
+            params.embedding = true;
+            params.pooling_type = LLAMA_POOLING_TYPE_RANK;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(common_arg(
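
In effect, `--rerank` is now just shorthand for enabling embeddings with RANK pooling. Below is a minimal sketch of the equivalent setup through the C API; the model path `reranker.gguf` is a placeholder and error handling is abbreviated:

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("reranker.gguf", mparams);
        if (!model) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings   = true;                     // output embeddings instead of only logits
        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;  // pool them into a relevance score

        llama_context * ctx = llama_init_from_model(model, cparams);

        // ... build a [query, document] batch, llama_decode(), read the score from the pooled output ...

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }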

common/common.cpp

Lines changed: 29 additions & 33 deletions
@@ -897,34 +897,6 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    if (params.reranking) {
-        bool ok = true;
-
-        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
-            ok = false;
-        }
-
-        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
-        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
-
-        if (!has_eos && !has_sep) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
-            ok = false;
-        } else if (!has_eos) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
-        } else if (!has_sep) {
-            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
-            ok = false;
-        }
-
-        if (!ok) {
-            llama_model_free(model);
-
-            return iparams;
-        }
-    }
-
     auto cparams = common_context_params_to_llama(params);
 
     llama_context * lctx = llama_init_from_model(model, cparams);
@@ -966,6 +938,35 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
     }
 
+    if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
+        bool ok = true;
+
+        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+
+        if (!has_eos && !has_sep) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+            ok = false;
+        } else if (!has_eos) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
+        } else if (!has_sep) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (!ok) {
+            llama_free(lctx);
+            llama_model_free(model);
+
+            return iparams;
+        }
+    }
+
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
         llama_adapter_lora_ptr lora;
@@ -1143,11 +1144,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full = params.swa_full;
 
-    if (params.reranking) {
-        cparams.embeddings = true;
-        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
-    }
-
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
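
Because the pooling type is now resolved on the context rather than implied by a CLI flag, the special-token check runs after `llama_init_from_model` and keys off `llama_pooling_type(lctx)`. A caller-side sketch of the same validation, using a hypothetical helper name `rerank_vocab_ok`:

    #include <cstdio>
    #include "llama.h"

    // sketch: mirror of the validation above; assumes the context was created from `model`
    static bool rerank_vocab_ok(const llama_model * model, const llama_context * ctx) {
        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_RANK) {
            return true; // not a reranking context, nothing to check
        }

        const llama_vocab * vocab = llama_model_get_vocab(model);

        const bool has_bos = llama_vocab_bos(vocab) != LLAMA_TOKEN_NULL;
        const bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
        const bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;

        if (!has_bos || (!has_eos && !has_sep)) {
            fprintf(stderr, "vocab lacks the BOS/EOS/SEP tokens needed for reranking\n");
            return false;
        }

        return true;
    }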

common/common.h

Lines changed: 0 additions & 1 deletion
@@ -355,7 +355,6 @@ struct common_params {
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
-    bool reranking = false; // enable reranking support on server
 
     // server params
     int32_t port = 8080; // server listens on this network port

examples/gritlm/gritlm.cpp

Lines changed: 5 additions & 3 deletions
@@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // add input to batch (this increments n_tokens)
         for (int32_t j = 0; j < n_toks; j++) {
-            common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
+            common_batch_add(batch, inputs[j], j, { 0 }, true);
         }
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_memory_clear(llama_get_memory(ctx), true);
-        llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);
 
         // run model
@@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_token eos_token = llama_vocab_eos(vocab);
 
     llama_memory_clear(llama_get_memory(ctx), true);
-    llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
     llama_model_params mparams = common_model_params_to_llama(params);
     llama_context_params cparams = common_context_params_to_llama(params);
 
+    cparams.embeddings = true;
+
     llama_backend_init();
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
         std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
     }
 
+    llama_set_embeddings(ctx, false);
+
     // ### Generation ###
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
    {
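
The GritLM example no longer flips embeddings mode inside `encode()` and `generate()`; the context starts with `cparams.embeddings = true` and is switched off once before the generation phase. A rough sketch of that flow, assuming `model` is already loaded (the `encode`/`generate` details are elided):

    // sketch: one context, embeddings first, generation second
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true;                // start in embeddings mode, as the example now does

    llama_context * ctx = llama_init_from_model(model, cparams);

    llama_set_causal_attn(ctx, false);        // representation pass (see encode() above)
    // ... decode the documents/queries and read their embeddings ...

    llama_set_embeddings(ctx, false);         // single switch before generation
    llama_set_causal_attn(ctx, true);         // generation pass (see generate() above)
    // ... sample tokens as usual ...

    llama_free(ctx);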

include/llama.h

Lines changed: 7 additions & 5 deletions
@@ -254,16 +254,19 @@ extern "C" {
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //            (if set to NULL, only the logits for last token will be returned)
+    //            (if set to NULL:
+    //            - if embeddings: all tokens are output
+    //            - if not:        only the last token is output
+    //            )
     //
     typedef struct llama_batch {
         int32_t n_tokens;
 
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
-        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
         int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;
 
@@ -961,8 +964,7 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
-    // Set whether the model is in embeddings mode or not
-    // If true, embeddings will be returned but logits will not
+    // Set whether the context outputs embeddings or not
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not
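
The `logits == NULL` default now depends on whether the context outputs embeddings. A small illustrative sketch, assuming `ctx` was created elsewhere; `decode_prompt` is a hypothetical helper, and `llama_batch_get_one` is used because it leaves `logits` as NULL:

    #include <vector>
    #include "llama.h"

    // sketch: what the new NULL-logits default means for a caller
    static int decode_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
        // llama_batch_get_one() produces a batch with logits == NULL
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

        // with logits == NULL:
        //  - embeddings enabled : every token is treated as an output (needed for pooling/rerank)
        //  - embeddings disabled: only the last token is an output (usual next-token decoding)
        return llama_decode(ctx, batch);
    }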

src/llama-batch.cpp

Lines changed: 26 additions & 4 deletions
@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
-        const llama_memory_i * memory) {
+        const llama_memory_i * memory,
+        bool embd_all) {
     clear();
 
     batch = batch_inp;
@@ -378,10 +379,31 @@
     }
 
     if (!batch.logits) {
-        // by default return the output only for the last token
-        output.resize(batch.n_tokens);
-        output[output.size() - 1] = true;
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
         batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
     }
 
     //
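
From the caller's side nothing changes in how outputs are requested: `common_batch_add` still takes a per-token output flag, and the new `embd_all` path only overrides those flags (with the warning above) when the context needs every position for embeddings. A sketch with a hypothetical helper `decode_with_last_output`, assuming `ctx` and a tokenized prompt `tokens` already exist:

    #include <vector>
    #include "common.h"
    #include "llama.h"

    // sketch: explicit per-token output flags via common_batch_add
    static int decode_with_last_output(llama_context * ctx, const std::vector<llama_token> & tokens) {
        llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);

        for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
            const bool is_last = (i == (int32_t) tokens.size() - 1);
            common_batch_add(batch, tokens[i], i, { 0 }, /*output=*/is_last); // only the last token
        }

        // in an embeddings context the allocator warns and overrides these flags,
        // marking every token as an output so that pooling sees all positions
        const int ret = llama_decode(ctx, batch);

        llama_batch_free(batch);
        return ret;
    }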

src/llama-batch.h

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ class llama_batch_allocr {
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
-            const llama_memory_i * memory);
+            const llama_memory_i * memory,
+            bool embd_all);
 
     const llama_batch & get_batch() const;

src/llama-context.cpp

Lines changed: 11 additions & 15 deletions
@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
-    auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+    auto * t_logits = res->get_logits();
     auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
 
     if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
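
Since `has_logits` is now always true and `t_logits` is no longer dropped when embeddings are enabled, one decode can serve both the completion and the embedding path, which is what lets the server handle both with a single model. A minimal fragment, assuming `ctx` has embeddings and a pooling type enabled, the model also has an output head, and `batch` is already prepared:

    // sketch: one decode, two kinds of output
    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits_ith(ctx, -1);      // completions path (last output)
        const float * embd   = llama_get_embeddings_seq(ctx, 0);   // embeddings/rerank path (sequence 0)

        if (logits && embd) {
            // ... sample the next token from `logits`, return/normalize `embd` ...
        }
    }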

src/llama-kv-cache-recurrent.cpp

Lines changed: 3 additions & 5 deletions
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
-
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
     while (sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
+        if (embd_all) {
+            // if all tokens are output, split by sequence
             ubatch = sbatch.split_seq(n_ubatch);
         } else {
             ubatch = sbatch.split_equal(n_ubatch);

src/llama-kv-cache-recurrent.h

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     // first try simple split
     do {

src/llama-kv-cache-unified-iswa.h

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-kv-cache-unified.cpp

Lines changed: 2 additions & 2 deletions
@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         const llama_batch & batch,
         uint32_t n_ubatch,
-        bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+        bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

src/llama-kv-cache-unified.h

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
