Skip to content

Commit 0fd6c1f

Browse files
committed
embedding : print cosine similarity (#899)
1 parent 19885d2 commit 0fd6c1f

File tree

4 files changed

+36
-25
lines changed

4 files changed

+36
-25
lines changed

common/common.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,3 +1877,16 @@ void llama_embd_normalize(const float * inp, float * out, int n) {
18771877
}
18781878
}
18791879

1880+
float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){
1881+
double sum = 0.0;
1882+
double sum1 = 0.0;
1883+
double sum2 = 0.0;
1884+
1885+
for (int i = 0; i < n; i++) {
1886+
sum += embd1[i] * embd2[i];
1887+
sum1 += embd1[i] * embd1[i];
1888+
sum2 += embd2[i] * embd2[i];
1889+
}
1890+
1891+
return sum / (sqrt(sum1) * sqrt(sum2));
1892+
}

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,4 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
268268

269269
void llama_embd_normalize(const float * inp, float * out, int n);
270270

271+
float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);

examples/embedding/embedding.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,25 @@ int main(int argc, char ** argv) {
168168
batch_decode(ctx, batch, out, s, n_embd);
169169

170170
// print first 3 embeddings
171+
fprintf(stdout, "\n");
171172
for (int j = 0; j < std::min(3, n_prompts); j++) {
172-
fprintf(stderr, "embedding %d: ", j);
173-
for (int i = 0; i < n_embd; i++) {
174-
fprintf(stderr, "%f ", emb[j * n_embd + i]);
173+
fprintf(stdout, "embedding %d: ", j);
174+
for (int i = 0; i < std::min(16, n_embd); i++) {
175+
fprintf(stdout, "%f ", emb[j * n_embd + i]);
175176
}
176-
fprintf(stderr, "\n\n");
177+
fprintf(stdout, "\n");
178+
}
179+
180+
// print cosine similarity matrix
181+
fprintf(stdout, "\n");
182+
printf("cosine similarity matrix:\n\n");
183+
for (int i = 0; i < n_prompts; i++) {
184+
for (int j = 0; j < n_prompts; j++) {
185+
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
186+
fprintf(stdout, "%6.2f ", sim);
187+
}
188+
fprintf(stdout, "\n");
177189
}
178-
fprintf(stderr, "\n");
179190

180191
// clean up
181192
llama_print_timings(ctx);

examples/gritlm/gritlm.cpp

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,6 @@
66

77
// #define GRIT_DEBUG
88

9-
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
10-
float dot = 0.0f;
11-
for (uint64_t i = 0; i < v1.size(); ++i) {
12-
dot += v1[i] * v2[i];
13-
}
14-
return dot;
15-
}
16-
17-
static float norm(const std::vector<float> & v) {
18-
return std::sqrt(dot_product(v, v));
19-
}
20-
21-
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
22-
return dot_product(v1, v2) / (norm(v1) * norm(v2));
23-
}
24-
259
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
2610
std::vector<std::vector<float>> result;
2711

@@ -203,10 +187,12 @@ int main(int argc, char * argv[]) {
203187
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
204188
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
205189

206-
const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
207-
const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
208-
const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
209-
const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
190+
const int n_embd = llama_n_embd(mdl);
191+
192+
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
193+
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
194+
const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
195+
const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
210196

211197
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
212198
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);

0 commit comments

Comments
 (0)