@@ -31,25 +31,47 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 }
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    const struct llama_model * model = llama_get_model(ctx);
+
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
     // run model
     fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_decode(ctx, batch) < 0) {
-        fprintf(stderr, "%s : failed to decode\n", __func__);
+    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
+        // encoder-only model
+        if (llama_encode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to encode\n", __func__);
+        }
+    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        // decoder-only model
+        if (llama_decode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to decode\n", __func__);
+        }
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;
         }
 
-        // try to get sequence embeddings - supported only when pooling_type is not NONE
-        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        const float * embd = nullptr;
+        int embd_pos = 0;
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            // try to get token embeddings
+            embd = llama_get_embeddings_ith(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
+        } else {
+            // try to get sequence embeddings - supported only when pooling_type is not NONE
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            embd_pos = batch.seq_id[i][0];
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        }
 
-        float * out = output + batch.seq_id[i][0] * n_embd;
+        float * out = output + embd_pos * n_embd;
         llama_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
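Note on the API used in the hunk above: with pooling disabled, embeddings are read back per token via llama_get_embeddings_ith(), while any pooled mode exposes one vector per sequence via llama_get_embeddings_seq() (which returns NULL when pooling is NONE, hence the branch). A minimal standalone sketch of that distinction, assuming a context created with embeddings enabled and a batch that has already been run through llama_encode()/llama_decode() (the helper name is illustrative, not part of the example):

#include "llama.h"

// sketch: choose the embedding getter based on the context's pooling type
static void read_embeddings_sketch(llama_context * ctx, const llama_batch & batch) {
    const int n_embd = llama_n_embd(llama_get_model(ctx));

    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
        // one embedding of n_embd floats per token that requested output
        for (int i = 0; i < batch.n_tokens; i++) {
            if (!batch.logits[i]) {
                continue;
            }
            const float * tok_embd = llama_get_embeddings_ith(ctx, i);
            (void) tok_embd; // consume the per-token embedding here
        }
    } else {
        // one pooled embedding of n_embd floats per sequence id in the batch
        const float * seq_embd = llama_get_embeddings_seq(ctx, batch.seq_id[0][0]);
        (void) seq_embd; // consume the pooled embedding here
    }
    (void) n_embd;
}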
@@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
     const int n_ctx = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+
+    if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
         return 1;
     }
 
@@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
     const int n_prompts = prompts.size();
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
+    // count number of embeddings
+    int n_embd_count = 0;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int k = 0; k < n_prompts; k++) {
+            n_embd_count += inputs[k].size();
+        }
+    } else {
+        n_embd_count = n_prompts;
+    }
+
     // allocate output
     const int n_embd = llama_n_embd(model);
-    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    std::vector<float> embeddings(n_embd_count * n_embd, 0);
     float * emb = embeddings.data();
 
     // break into batches
-    int p = 0; // number of prompts processed already
+    int e = 0; // number of embeddings already stored
     int s = 0; // number of prompts in current batch
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
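Worked example for the count introduced above (hypothetical token counts): with three prompts that tokenize to 5, 7 and 4 tokens, pooling type NONE gives n_embd_count = 16 and an output buffer of 16 * n_embd floats (one row per token); any pooled mode gives n_embd_count = 3 (one row per prompt).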
@@ -169,11 +202,11 @@ int main(int argc, char ** argv) {
 
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
+            float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            llama_batch_clear(batch);
-            p += s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
+            llama_batch_clear(batch);
         }
 
         // add to batch
@@ -182,39 +215,62 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
+    float * out = emb + e * n_embd;
     batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
-        // print the first part of the embeddings or for a single prompt, the full embedding
         fprintf(stdout, "\n");
-        for (int j = 0; j < n_prompts; j++) {
-            fprintf(stdout, "embedding %d: ", j);
-            for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
-                if (params.embd_normalize == 0) {
-                    fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
-                } else {
-                    fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            for (int j = 0; j < n_embd_count; j++) {
+                fprintf(stdout, "embedding %d: ", j);
+                for (int i = 0; i < std::min(3, n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                fprintf(stdout, " ... ");
+                for (int i = n_embd - 3; i < n_embd; i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
                 }
+                fprintf(stdout, "\n");
             }
-            fprintf(stdout, "\n");
-        }
-
-        // print cosine similarity matrix
-        if (n_prompts > 1) {
-            fprintf(stdout, "\n");
-            printf("cosine similarity matrix:\n\n");
-            for (int i = 0; i < n_prompts; i++) {
-                fprintf(stdout, "%6.6s ", prompts[i].c_str());
+        } else {
+            // print the first part of the embeddings or for a single prompt, the full embedding
+            for (int j = 0; j < n_prompts; j++) {
+                fprintf(stdout, "embedding %d: ", j);
+                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                fprintf(stdout, "\n");
             }
-            fprintf(stdout, "\n");
-            for (int i = 0; i < n_prompts; i++) {
-                for (int j = 0; j < n_prompts; j++) {
-                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                    fprintf(stdout, "%6.2f ", sim);
+
+            // print cosine similarity matrix
+            if (n_prompts > 1) {
+                fprintf(stdout, "\n");
+                printf("cosine similarity matrix:\n\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    fprintf(stdout, "%6.6s ", prompts[i].c_str());
                 }
-                fprintf(stdout, "%1.10s", prompts[i].c_str());
                 fprintf(stdout, "\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    for (int j = 0; j < n_prompts; j++) {
+                        float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                        fprintf(stdout, "%6.2f ", sim);
+                    }
+                    fprintf(stdout, "%1.10s", prompts[i].c_str());
+                    fprintf(stdout, "\n");
+                }
             }
         }
     }
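The similarity matrix printed above relies on the llama_embd_similarity_cos() helper declared in common.h. For reference, a self-contained cosine-similarity helper of the same shape (a sketch with an illustrative name, not the library's implementation) could look like:

#include <cmath>

// cosine similarity between two embedding vectors of length n
static float embd_cos_sim_sketch(const float * a, const float * b, int n) {
    double dot = 0.0, na = 0.0, nb = 0.0;
    for (int i = 0; i < n; i++) {
        dot += a[i] * b[i];
        na  += a[i] * a[i];
        nb  += b[i] * b[i];
    }
    if (na == 0.0 || nb == 0.0) {
        return 0.0f; // define similarity involving a zero vector as 0
    }
    return (float) (dot / (std::sqrt(na) * std::sqrt(nb)));
}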
@@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
             }
             fprintf(stdout, notArray ? "]\n    }" : "]");
             j++;
-            if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
+            if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
         }
         fprintf(stdout, notArray ? "\n  ]" : "]\n");
 
         if (params.embd_out == "json+" && n_prompts > 1) {
             fprintf(stdout, ",\n  \"cosineSimilarity\": [\n");
-            for (int i = 0;;) { // at least two iteration (n_prompts > 1)
+            for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
                 fprintf(stdout, "    [");
-                for (int j = 0;;) { // at least two iteration (n_prompts > 1)
+                for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
                     float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
                     fprintf(stdout, "%6.2f", sim);
                     j++;
-                    if (j < n_prompts) fprintf(stdout, ", "); else break;
+                    if (j < n_embd_count) fprintf(stdout, ", "); else break;
                 }
                 fprintf(stdout, " ]");
                 i++;
-                if (i < n_prompts) fprintf(stdout, ",\n"); else break;
+                if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
             }
             fprintf(stdout, "\n  ]");
         }