
Commit 1549493

batched embedding: pool outputs by sequence id. updated embedding example
1 parent dbd8828 commit 1549493

6 files changed: 158 additions & 44 deletions

convert-hf-to-gguf.py

Lines changed: 1 addition & 0 deletions
@@ -1648,6 +1648,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
         self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_pooling_layer(True)
         self.gguf_writer.add_file_type(self.ftype)

     def set_vocab(self):

examples/embedding/embedding.cpp

Lines changed: 110 additions & 35 deletions
@@ -7,6 +7,59 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif

+static std::vector<std::string> split_lines(const std::string & s) {
+    std::string line;
+    std::vector<std::string> lines;
+    std::stringstream ss(s);
+    while (std::getline(ss, line)) {
+        lines.push_back(line);
+    }
+    return lines;
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+    const uint64_t n_tokens = tokens.size();
+    int n_past = batch.n_tokens;
+    batch.n_tokens += n_tokens;
+    for (uint64_t i = 0; i < n_tokens; i++) {
+        uint64_t j = n_past + i;
+        batch.token[j] = tokens[i];
+        batch.pos[j] = i;
+        batch.n_seq_id[j] = 1;
+        batch.seq_id[j][0] = seq_id;
+        batch.logits[j] = 0;
+    }
+}
+
+static void normalize(float * vec, float * out, int n) {
+    float norm = 0;
+    for (int i = 0; i < n; i++) {
+        norm += vec[i] * vec[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n; i++) {
+        out[i] = vec[i] / norm;
+    }
+}
+
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+    // clear previous kv_cache values (irrelevant for embeddings)
+    llama_kv_cache_clear(ctx);
+
+    // run model
+    fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    if (llama_decode(ctx, batch) < 0) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
+    }
+
+    // normalize on copy
+    for (int k = 0; k < n_seq; k++) {
+        float * emb = llama_get_embeddings_ith(ctx, k);
+        float * out = output + k * n_embd;
+        normalize(emb, out, n_embd);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;

@@ -55,59 +108,81 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }

-    int n_past = 0;
+    // split the prompt into lines
+    std::vector<std::string> prompts = split_lines(params.prompt);

-    // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    // max batch size
+    const uint64_t n_batch = params.n_batch;
+    GGML_ASSERT(params.n_batch == params.n_ctx);

-    if (params.verbose_prompt) {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-        fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-        for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
+    // tokenize the prompts and trim
+    std::vector<std::vector<int32_t>> inputs;
+    for (const auto & prompt : prompts) {
+        auto inp = ::llama_tokenize(ctx, prompt, true);
+        if (inp.size() > n_batch) {
+            inp.resize(n_batch);
         }
-        fprintf(stderr, "\n");
+        inputs.push_back(inp);
     }

-    if (embd_inp.size() > (size_t)n_ctx) {
-        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
-                __func__, embd_inp.size(), n_ctx);
-        return 1;
-    }
-
-    while (!embd_inp.empty()) {
-        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return 1;
+    // tokenization stats
+    if (params.verbose_prompt) {
+        for (int i = 0; i < (int) inputs.size(); i++) {
+            fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+            fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+            for (int j = 0; j < (int) inputs[i].size(); j++) {
+                fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
+            }
+            fprintf(stderr, "\n\n");
         }
-        n_past += n_tokens;
-        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }

+    // initialize batch
+    const int n_prompts = prompts.size();
+    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+
+    // allocate output
     const int n_embd = llama_n_embd(model);
-    auto * embeddings = llama_get_embeddings(ctx);
+    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    float * emb = embeddings.data();
+
+    // break into batches
+    int p = 0; // number of prompts processed already
+    int s = 0; // number of prompts in current batch
+    for (int k = 0; k < n_prompts; k++) {
+        // clamp to n_batch tokens
+        auto & inp = inputs[k];
+        const uint64_t n_toks = inp.size();
+
+        // encode if at capacity
+        if (batch.n_tokens + n_toks > n_batch) {
+            float * out = emb + p * n_embd;
+            batch_decode(ctx, batch, out, s, n_embd);
+            batch.n_tokens = 0;
+            p += s;
+            s = 0;
+        }

-    // l2-normalize embeddings
-    float norm = 0;
-    for (int i = 0; i < n_embd; i++) {
-        norm += embeddings[i] * embeddings[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n_embd; i++) {
-        embeddings[i] /= norm;
+        // add to batch
+        batch_add_seq(batch, inp, s);
+        s += 1;
     }

+    // final batch
+    float * out = emb + p * n_embd;
+    batch_decode(ctx, batch, out, s, n_embd);
+
+    // print first embedding
+    fprintf(stderr, "\nfirst embedding:\n");
     for (int i = 0; i < n_embd; i++) {
-        printf("%f ", embeddings[i]);
+        fprintf(stderr, "%f ", emb[i]);
     }
-    printf("\n");
+    fprintf(stderr, "\n");

+    // clean up
     llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);
-
     llama_backend_free();

     return 0;
gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ class LLM:
        TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
        EXPERT_COUNT       = "{arch}.expert_count"
        EXPERT_USED_COUNT  = "{arch}.expert_used_count"
+       POOLING_LAYER      = "{arch}.pooling_layer"

    class Attention:
        HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -360,6 +360,9 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

+    def add_pooling_layer(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

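
The new {arch}.pooling_layer boolean written by add_pooling_layer is what llm_load_hparams reads back in llama.cpp below. A rough way to inspect the flag on a converted model is the gguf C API declared in ggml.h (sketch only, assuming gguf_init_from_file / gguf_find_key / gguf_get_val_bool; the "bert.pooling_layer" key name is just the {arch}.pooling_layer pattern for the bert arch):

// Sketch: read the pooling_layer flag from a converted GGUF file.
#include "ggml.h"
#include <cstdio>

static int check_pooling_flag(const char * fname) {
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        return -1; // not a readable GGUF file
    }
    const int kid = gguf_find_key(ctx, "bert.pooling_layer"); // {arch}.pooling_layer
    const bool pooling = kid >= 0 && gguf_get_val_bool(ctx, kid);
    fprintf(stderr, "pooling_layer = %s\n", pooling ? "true" : "false");
    gguf_free(ctx);
    return pooling ? 1 : 0;
}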

llama.cpp

Lines changed: 38 additions & 9 deletions
@@ -254,6 +254,7 @@ enum llm_kv {
    LLM_KV_TENSOR_DATA_LAYOUT,
    LLM_KV_EXPERT_COUNT,
    LLM_KV_EXPERT_USED_COUNT,
+   LLM_KV_POOLING_LAYER,

    LLM_KV_ATTENTION_HEAD_COUNT,
    LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -311,6 +312,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
    { LLM_KV_EXPERT_COUNT,       "%s.expert_count"       },
    { LLM_KV_EXPERT_USED_COUNT,  "%s.expert_used_count"  },
+   { LLM_KV_POOLING_LAYER,      "%s.pooling_layer"      },

    { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1524,6 +1526,7 @@ struct llama_hparams {
    float f_max_alibi_bias;

    bool causal_attn = true;
+   bool pooling_layer = false;


    bool operator!=(const llama_hparams & other) const {
@@ -1586,6 +1589,7 @@ struct llama_cparams {

    bool mul_mat_q;
    bool offload_kqv;
+   bool do_pooling;

    ggml_backend_sched_eval_callback cb_eval;
    void * cb_eval_user_data;
@@ -1881,7 +1885,7 @@ struct llama_context {
    struct ggml_tensor * inp_pos;     // I32 [n_batch]
    struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
    struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
-   struct ggml_tensor * inp_sum;     // F32 [1, n_batch]
+   struct ggml_tensor * inp_sum;     // F32 [n_batch, n_batch]

#ifdef GGML_USE_MPI
    ggml_mpi_context * ctx_mpi = NULL;
@@ -3038,6 +3042,7 @@ static void llm_load_hparams(
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+               ml.get_key(LLM_KV_POOLING_LAYER, hparams.pooling_layer);

                switch (hparams.n_layer) {
                    case 3:
@@ -4845,6 +4850,7 @@ struct llm_build_context {

    const bool do_rope_shift;
    const bool causal_attn;
+   const bool do_pooling;

    const llm_build_cb & cb;

@@ -4889,6 +4895,7 @@
        n_orig_ctx       (cparams.n_yarn_orig_ctx),
        do_rope_shift    (worst_case || kv_self.has_shift),
        causal_attn      (hparams.causal_attn),
+       do_pooling       (hparams.pooling_layer && cparams.do_pooling),
        cb               (cb),
        buf_compute_meta (lctx.buf_compute_meta) {
            // all initializations should be done in init()
@@ -5737,14 +5744,14 @@

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-       GGML_ASSERT(n_embd_head == hparams.n_rot);

        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;

        // get input vectors with right size
+       const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-       struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+       struct ggml_tensor * inp_sum = ggml_view_2d(ctx0, lctx.inp_sum, n_tokens, n_tokens, stride1, 0);

        // construct input embeddings (token, type, position)
        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
@@ -5817,8 +5824,10 @@
        // final output
        cur = inpL;

-       // pooling
-       cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
+       // pooling layer
+       if (do_pooling) {
+           cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
+       }
        cb(cur, "result_embed", -1);

        ggml_build_forward_expand(gf, cur);
@@ -7384,6 +7393,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
            data[i] = lctx.kv_self.cells[i].delta;
        }
    }
+
+   if (hparams.pooling_layer && cparams.do_pooling) {
+       const int64_t n_tokens = batch.n_tokens;
+
+       GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+       float * data = (float *) lctx.inp_sum->data;
+
+       memset(lctx.inp_sum->data, 0, batch.n_tokens * batch.n_tokens * ggml_element_size(lctx.inp_sum));
+       for (int i = 0; i < n_tokens; ++i) {
+           const llama_seq_id seq_id = batch.seq_id[i][0];
+           data[seq_id*n_tokens + i] = 1.0f;
+       }
+   }
}

// decode a batch of tokens by evaluating the transformer
@@ -7616,10 +7638,11 @@ static int llama_decode_internal(
        auto & embedding_out = lctx.embedding;

        const int64_t embed_pos  = res ? n_embd * (n_tokens-1) : 0;
+       const int64_t embed_size = res ? n_embd : n_embd * n_tokens;

-       embedding_out.resize(n_embd);
+       embedding_out.resize(embed_size);
        ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-       ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
+       ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), embed_size*sizeof(float));
        ggml_backend_synchronize(embeddings_backend);
    }

@@ -10930,6 +10953,7 @@ struct llama_context_params llama_context_default_params() {
        /*.logits_all  =*/ false,
        /*.embedding   =*/ false,
        /*.offload_kqv =*/ true,
+       /*.do_pooling  =*/ true,
    };

    return result;
@@ -11085,6 +11109,7 @@ struct llama_context * llama_new_context_with_model(
    cparams.yarn_beta_slow = params.yarn_beta_slow;
    cparams.mul_mat_q      = params.mul_mat_q;
    cparams.offload_kqv    = params.offload_kqv;
+   cparams.do_pooling     = params.do_pooling;

    cparams.n_ctx          = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
    cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -11232,7 +11257,7 @@ struct llama_context * llama_new_context_with_model(
        // resized during inference, reserve maximum
        ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);

-       if (params.embedding){
+       if (params.embedding) {
            ctx->embedding.resize(hparams.n_embd);
        }

@@ -11250,7 +11275,7 @@ struct llama_context * llama_new_context_with_model(
        ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
        ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
        ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
-       ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
+       ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);

        ggml_set_name(ctx->inp_tokens, "inp_tokens");
        ggml_set_name(ctx->inp_embd, "inp_embd");
@@ -12108,6 +12133,10 @@ float * llama_get_embeddings(struct llama_context * ctx) {
    return ctx->embedding.data();
}

+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+   return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
+}
+
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
    return model->vocab.id_to_token[token].text.c_str();
}
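
With inp_sum resized to [n_batch, n_batch] and filled per batch as a 0/1 indicator (row = sequence id, column = batch position), the final ggml_mul_mat in the BERT graph collapses the last-layer token embeddings into one summed row per sequence. A standalone sketch of that arithmetic (plain C++, not llama.cpp code):

// Illustration of what the one-hot inp_sum matrix computes: row `seq_id`
// selects the batch positions belonging to that sequence, so the matmul
// reduces to a per-sequence sum of token embeddings.
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 5, n_seq = 2, n_embd = 3;
    // sequence id per batch position: tokens 0-2 -> sequence 0, tokens 3-4 -> sequence 1
    const int seq_id[n_tokens] = { 0, 0, 0, 1, 1 };

    // fake per-token embeddings, row-major [n_tokens][n_embd]
    std::vector<float> tok_embd = {
        1, 0, 0,
        0, 1, 0,
        0, 0, 1,
        2, 2, 2,
        1, 1, 1,
    };

    // build the indicator matrix the same way llama_set_inputs does
    std::vector<float> inp_sum(n_tokens * n_tokens, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        inp_sum[seq_id[i] * n_tokens + i] = 1.0f;
    }

    // pooled[s] = sum over tokens i with seq_id[i] == s of tok_embd[i]
    for (int s = 0; s < n_seq; ++s) {
        float pooled[n_embd] = { 0 };
        for (int i = 0; i < n_tokens; ++i) {
            for (int d = 0; d < n_embd; ++d) {
                pooled[d] += inp_sum[s * n_tokens + i] * tok_embd[i * n_embd + d];
            }
        }
        printf("seq %d: %.1f %.1f %.1f\n", s, pooled[0], pooled[1], pooled[2]);
    }
    return 0; // prints: seq 0: 1.0 1.0 1.0 and seq 1: 3.0 3.0 3.0
}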

llama.h

Lines changed: 5 additions & 0 deletions
@@ -236,6 +236,7 @@ extern "C" {
        bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
        bool embedding;   // embedding mode only
        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+       bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
    };

    // model quantization parameters
@@ -628,6 +629,10 @@ extern "C" {
    // shape: [n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

+   // Get the embeddings for the ith token
+   // llama_get_embeddings(ctx) + i*n_embd
+   LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+
    //
    // Vocab
    //
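
Putting the new pieces together, a minimal calling sequence could look as follows (a sketch only; model loading, tokenization and cleanup are omitted, model / n_batch / tokens_a / tokens_b are assumed to be prepared, and batch_add_seq is the helper added to examples/embedding/embedding.cpp above):

// Sketch: pooled embeddings for two prompts in one decode call.
llama_context_params cparams = llama_context_default_params();
cparams.embedding  = true; // embedding mode only
cparams.do_pooling = true; // default; pooled (summed) output per sequence id

llama_context * ctx = llama_new_context_with_model(model, cparams);

llama_batch batch = llama_batch_init(n_batch, 0, 2);
batch_add_seq(batch, tokens_a, 0); // first prompt as sequence 0
batch_add_seq(batch, tokens_b, 1); // second prompt as sequence 1

if (llama_decode(ctx, batch) < 0) {
    fprintf(stderr, "decode failed\n");
}

const int n_embd = llama_n_embd(model);
float * embd_a = llama_get_embeddings_ith(ctx, 0); // pooled vector of sequence 0
float * embd_b = llama_get_embeddings_ith(ctx, 1); // pooled vector of sequence 1

Each returned pointer addresses n_embd floats inside the context's embedding buffer, per the llama_get_embeddings_ith definition above.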
