
Commit 03bf161

iamlemec and ggerganov authored

llama : support batched embeddings (#5466)

* batched embedding: pool outputs by sequence id. updated embedding example
* bring back non-causal attention
* embd : minor improvements
* llama : minor

Co-authored-by: Georgi Gerganov <[email protected]>

1 parent ad014bb commit 03bf161
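
In short, the updated example gives every prompt its own sequence id inside a single llama_batch, runs one llama_decode with non-causal attention, and reads back one pooled vector per sequence via llama_get_embeddings_ith. A condensed sketch of that flow, assuming all prompts fit into one batch (the embed_batched helper and its minimal error handling are illustrative, not part of the commit; llama_batch_add is the helper from common.h):

    #include <algorithm>
    #include <vector>

    #include "common.h"
    #include "llama.h"

    // Illustrative only: embed each tokenized prompt as its own sequence in one
    // batch and copy back one pooled embedding per sequence. Assumes the model
    // was converted with the new pooling flag and all prompts fit in one batch.
    static std::vector<float> embed_batched(llama_context * ctx, const std::vector<std::vector<llama_token>> & prompts) {
        const int n_embd    = llama_n_embd(llama_get_model(ctx));
        const int n_prompts = (int) prompts.size();

        int n_tokens = 0;
        for (const auto & p : prompts) {
            n_tokens += (int) p.size();
        }

        llama_batch batch = llama_batch_init(n_tokens, 0, n_prompts);
        for (int s = 0; s < n_prompts; s++) {
            for (size_t i = 0; i < prompts[s].size(); i++) {
                llama_batch_add(batch, prompts[s][i], (llama_pos) i, { s }, false); // same seq_id for the whole prompt
            }
        }

        std::vector<float> out(n_prompts * n_embd, 0.0f);

        llama_kv_cache_clear(ctx); // the KV cache is irrelevant for embeddings
        if (llama_decode(ctx, batch) == 0) {
            for (int s = 0; s < n_prompts; s++) {
                const float * emb = llama_get_embeddings_ith(ctx, s); // pooled embedding of sequence s
                std::copy(emb, emb + n_embd, out.begin() + s * n_embd);
            }
        }

        llama_batch_free(batch);
        return out;
    }

The diff below is the full version of this pattern, including splitting across multiple batches when the prompts do not all fit.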

File tree

6 files changed (+161, -52 lines)


convert-hf-to-gguf.py

Lines changed: 1 addition & 0 deletions

@@ -1648,6 +1648,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
         self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_pooling_layer(True)
         self.gguf_writer.add_file_type(self.ftype)
 
     def set_vocab(self):

examples/embedding/embedding.cpp

Lines changed: 106 additions & 36 deletions

@@ -7,6 +7,51 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static std::vector<std::string> split_lines(const std::string & s) {
+    std::string line;
+    std::vector<std::string> lines;
+    std::stringstream ss(s);
+    while (std::getline(ss, line)) {
+        lines.push_back(line);
+    }
+    return lines;
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+    for (size_t i = 0; i < tokens.size(); i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+    }
+}
+
+static void normalize(float * vec, float * out, int n) {
+    float norm = 0;
+    for (int i = 0; i < n; i++) {
+        norm += vec[i] * vec[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n; i++) {
+        out[i] = vec[i] / norm;
+    }
+}
+
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+    // clear previous kv_cache values (irrelevant for embeddings)
+    llama_kv_cache_clear(ctx);
+
+    // run model
+    fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    if (llama_decode(ctx, batch) < 0) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
+    }
+
+    // normalize on copy
+    for (int k = 0; k < n_seq; k++) {
+        float * emb = llama_get_embeddings_ith(ctx, k);
+        float * out = output + k * n_embd;
+        normalize(emb, out, n_embd);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
 
-    int n_past = 0;
+    // split the prompt into lines
+    std::vector<std::string> prompts = split_lines(params.prompt);
 
-    // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    // max batch size
+    const uint64_t n_batch = params.n_batch;
+    GGML_ASSERT(params.n_batch == params.n_ctx);
 
-    if (params.verbose_prompt) {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-        fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-        for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
+    // tokenize the prompts and trim
+    std::vector<std::vector<int32_t>> inputs;
+    for (const auto & prompt : prompts) {
+        auto inp = ::llama_tokenize(ctx, prompt, true);
+        if (inp.size() > n_batch) {
+            inp.resize(n_batch);
         }
-        fprintf(stderr, "\n");
+        inputs.push_back(inp);
     }
 
-    if (embd_inp.size() > (size_t)n_ctx) {
-        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
-                __func__, embd_inp.size(), n_ctx);
-        return 1;
-    }
-
-    while (!embd_inp.empty()) {
-        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return 1;
+    // tokenization stats
+    if (params.verbose_prompt) {
+        for (int i = 0; i < (int) inputs.size(); i++) {
+            fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+            fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+            for (int j = 0; j < (int) inputs[i].size(); j++) {
+                fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
+            }
+            fprintf(stderr, "\n\n");
        }
-        n_past += n_tokens;
-        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }
 
+    // initialize batch
+    const int n_prompts = prompts.size();
+    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+
+    // allocate output
     const int n_embd = llama_n_embd(model);
-    auto * embeddings = llama_get_embeddings(ctx);
+    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    float * emb = embeddings.data();
+
+    // break into batches
+    int p = 0; // number of prompts processed already
+    int s = 0; // number of prompts in current batch
+    for (int k = 0; k < n_prompts; k++) {
+        // clamp to n_batch tokens
+        auto & inp = inputs[k];
+        const uint64_t n_toks = inp.size();
+
+        // encode if at capacity
+        if (batch.n_tokens + n_toks > n_batch) {
+            float * out = emb + p * n_embd;
+            batch_decode(ctx, batch, out, s, n_embd);
+            llama_batch_clear(batch);
+            p += s;
+            s = 0;
+        }
 
-    // l2-normalize embeddings
-    float norm = 0;
-    for (int i = 0; i < n_embd; i++) {
-        norm += embeddings[i] * embeddings[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n_embd; i++) {
-        embeddings[i] /= norm;
+        // add to batch
+        batch_add_seq(batch, inp, s);
+        s += 1;
    }
 
-    for (int i = 0; i < n_embd; i++) {
-        printf("%f ", embeddings[i]);
+    // final batch
+    float * out = emb + p * n_embd;
+    batch_decode(ctx, batch, out, s, n_embd);
+
+    // print first 3 embeddings
+    for (int j = 0; j < std::min(3, n_prompts); j++) {
+        fprintf(stderr, "embedding %d: ", j);
+        for (int i = 0; i < n_embd; i++) {
+            fprintf(stderr, "%f ", emb[j * n_embd + i]);
+        }
+        fprintf(stderr, "\n\n");
     }
-    printf("\n");
+    fprintf(stderr, "\n");
 
+    // clean up
     llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);
-
     llama_backend_free();
 
     return 0;
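
Since batch_decode L2-normalizes each pooled vector on copy, comparing two prompts only takes a dot product over rows of the output buffer. A small follow-on sketch (the cosine_sim helper is illustrative, not part of the commit; emb and n_embd follow the layout used in main above):

    // Illustrative only: cosine similarity between the pooled embeddings of
    // prompts j and k; both rows already have unit norm, so their dot product
    // is the cosine similarity.
    static float cosine_sim(const float * emb, int n_embd, int j, int k) {
        float dot = 0.0f;
        for (int i = 0; i < n_embd; i++) {
            dot += emb[j * n_embd + i] * emb[k * n_embd + i];
        }
        return dot;
    }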

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ class LLM:
         TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
         EXPERT_COUNT       = "{arch}.expert_count"
         EXPERT_USED_COUNT  = "{arch}.expert_used_count"
+        POOLING_LAYER      = "{arch}.pooling_layer"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions

@@ -360,6 +360,9 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 
+    def add_pooling_layer(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
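
The flag ends up in the GGUF metadata as a boolean under "{arch}.pooling_layer" (see constants.py above). To spot-check a converted file, it can be read back with ggml's gguf API; a minimal sketch (the standalone program and the "bert" architecture prefix are illustrative assumptions, not part of the commit):

    #include <cstdio>

    #include "ggml.h"

    // Illustrative only: open a GGUF file without loading tensor data and print
    // the pooling flag written by add_pooling_layer().
    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
            return 1;
        }

        struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
        struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
        if (ctx == NULL) {
            fprintf(stderr, "failed to open %s\n", argv[1]);
            return 1;
        }

        const int kid = gguf_find_key(ctx, "bert.pooling_layer"); // arch prefix assumed to be "bert"
        if (kid >= 0) {
            printf("pooling_layer = %s\n", gguf_get_val_bool(ctx, kid) ? "true" : "false");
        } else {
            printf("pooling_layer key not present\n");
        }

        gguf_free(ctx);
        return 0;
    }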
