Skip to content

Commit 373f916

Browse files
redlion0929 and ggerganov
authored and committed
server : normalize embeddings (ggml-org#5956)
* output normalize embedding in '/v1/embeddings' * common : reuse llama_embd_normalize * common : better normalize impl --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f0fa237 commit 373f916

File tree

4 files changed

+30
-14
lines changed

4 files changed

+30
-14
lines changed

common/common.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,3 +1852,18 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
18521852

18531853
printf("\n=== Done dumping\n");
18541854
}
1855+
1856+
// Normalize the n-element embedding vector `inp` to unit (L2) length,
// writing the result to `out` (in-place use, inp == out, is safe).
//
//   inp : input vector of n floats
//   out : output buffer of n floats
//   n   : number of elements
//
// An all-zero input yields an all-zero output (no division by zero).
void llama_embd_normalize(const float * inp, float * out, int n) {
    // accumulate the squared magnitude in double precision: promoting each
    // element BEFORE squaring avoids float round-off and overflow of x*x
    // for |x| > ~1.8e19, which a float-precision square would hit
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        const double v = inp[i];
        sum += v * v;
    }
    const double len = sqrt(sum);

    // guard against an all-zero vector (len == 0) to avoid producing NaNs
    const float norm = len > 0.0 ? (float)(1.0 / len) : 0.0f;

    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * norm;
    }
}

common/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
260260

261261
// Dump the KV cache view showing individual sequences in each cell (long output).
262262
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
263+
264+
//
265+
// Embedding utils
266+
//
267+
268+
void llama_embd_normalize(const float * inp, float * out, int n);
269+

examples/embedding/embedding.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
2323
}
2424
}
2525

26-
// Normalize `vec` (n floats) to unit L2 length into `out`.
//
//   vec : input vector of n floats
//   out : output buffer of n floats
//   n   : number of elements
//
// The squared magnitude is accumulated in double precision (each element
// is promoted before squaring) to avoid float round-off and overflow,
// and an all-zero input yields an all-zero output instead of NaNs from
// dividing by a zero norm.
static void normalize(const float * vec, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        const double v = vec[i];
        sum += v * v;
    }
    const double len = sqrt(sum);

    // guard: dividing by a zero norm would fill `out` with NaNs
    const float scale = len > 0.0 ? (float)(1.0 / len) : 0.0f;

    for (int i = 0; i < n; i++) {
        out[i] = vec[i] * scale;
    }
}
36-
3726
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
3827
// clear previous kv_cache values (irrelevant for embeddings)
3928
llama_kv_cache_clear(ctx);
@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
4433
fprintf(stderr, "%s : failed to decode\n", __func__);
4534
}
4635

47-
// normalize on copy
4836
for (int i = 0; i < batch.n_tokens; i++) {
4937
if (!batch.logits[i]) {
5038
continue;
@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
6149
}
6250

6351
float * out = output + batch.seq_id[i][0] * n_embd;
64-
normalize(embd, out, n_embd);
52+
llama_embd_normalize(embd, out, n_embd);
6553
}
6654
}
6755

examples/server/server.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,8 @@ struct server_context {
13271327

13281328
const int n_embd = llama_n_embd(model);
13291329

1330+
std::vector<float> embd_res(n_embd, 0.0f);
1331+
13301332
for (int i = 0; i < batch.n_tokens; ++i) {
13311333
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
13321334
continue;
@@ -1350,8 +1352,10 @@ struct server_context {
13501352
continue;
13511353
}
13521354

1355+
llama_embd_normalize(embd, embd_res.data(), n_embd);
1356+
13531357
res.data = json {
1354-
{"embedding", std::vector<float>(embd, embd + n_embd)},
1358+
{"embedding", embd_res},
13551359
};
13561360
}
13571361

@@ -3354,6 +3358,8 @@ int main(int argc, char ** argv) {
33543358
// get the result
33553359
server_task_result result = ctx_server.queue_results.recv(id_task);
33563360
ctx_server.queue_results.remove_waiting_task_id(id_task);
3361+
3362+
// append to the responses
33573363
responses.push_back(result.data);
33583364
}
33593365

0 commit comments

Comments
 (0)