 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static std::vector<std::string> split_lines(const std::string & s) {
+    std::string line;
+    std::vector<std::string> lines;
+    std::stringstream ss(s);
+    while (std::getline(ss, line)) {
+        lines.push_back(line);
+    }
+    return lines;
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+    for (size_t i = 0; i < tokens.size(); i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+    }
+}
+
+static void normalize(float * vec, float * out, int n) {
+    float norm = 0;
+    for (int i = 0; i < n; i++) {
+        norm += vec[i] * vec[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n; i++) {
+        out[i] = vec[i] / norm;
+    }
+}
+
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+    // clear previous kv_cache values (irrelevant for embeddings)
+    llama_kv_cache_clear(ctx);
+
+    // run model
+    fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    if (llama_decode(ctx, batch) < 0) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
+    }
+
+    // normalize on copy
+    for (int k = 0; k < n_seq; k++) {
+        float * emb = llama_get_embeddings_ith(ctx, k);
+        float * out = output + k * n_embd;
+        normalize(emb, out, n_embd);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
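Note: batch_add_seq() gives every prompt its own sequence id, with positions restarting at 0, so several prompts can share one llama_batch while staying independent of each other. A sketch with hypothetical tokens (illustration only, not part of the diff):

    // batch_add_seq(batch, {tokA0, tokA1, tokA2}, 0);
    // batch_add_seq(batch, {tokB0, tokB1},        1);
    //
    // token   pos   seq_id
    // tokA0    0      0
    // tokA1    1      0
    // tokA2    2      0
    // tokB0    0      1
    // tokB1    1      1
    //
    // sequences attend only to themselves, so a single llama_decode() call
    // yields an independent embedding per prompt, read back per sequence
    // with llama_get_embeddings_ith() in batch_decode() above.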
@@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
 
-    int n_past = 0;
+    // split the prompt into lines
+    std::vector<std::string> prompts = split_lines(params.prompt);
 
-    // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    // max batch size
+    const uint64_t n_batch = params.n_batch;
+    GGML_ASSERT(params.n_batch == params.n_ctx);
 
-    if (params.verbose_prompt) {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-        fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-        for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
+    // tokenize the prompts and trim
+    std::vector<std::vector<int32_t>> inputs;
+    for (const auto & prompt : prompts) {
+        auto inp = ::llama_tokenize(ctx, prompt, true);
+        if (inp.size() > n_batch) {
+            inp.resize(n_batch);
         }
-        fprintf(stderr, "\n");
+        inputs.push_back(inp);
     }
 
-    if (embd_inp.size() > (size_t)n_ctx) {
-        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
-                __func__, embd_inp.size(), n_ctx);
-        return 1;
-    }
-
-    while (!embd_inp.empty()) {
-        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return 1;
+    // tokenization stats
+    if (params.verbose_prompt) {
+        for (int i = 0; i < (int) inputs.size(); i++) {
+            fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+            fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+            for (int j = 0; j < (int) inputs[i].size(); j++) {
+                fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
+            }
+            fprintf(stderr, "\n\n");
         }
-        n_past += n_tokens;
-        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }
 
+    // initialize batch
+    const int n_prompts = prompts.size();
+    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+
+    // allocate output
     const int n_embd = llama_n_embd(model);
-    auto * embeddings = llama_get_embeddings(ctx);
+    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    float * emb = embeddings.data();
+
+    // break into batches
+    int p = 0; // number of prompts processed already
+    int s = 0; // number of prompts in current batch
+    for (int k = 0; k < n_prompts; k++) {
+        // clamp to n_batch tokens
+        auto & inp = inputs[k];
+        const uint64_t n_toks = inp.size();
+
+        // encode if at capacity
+        if (batch.n_tokens + n_toks > n_batch) {
+            float * out = emb + p * n_embd;
+            batch_decode(ctx, batch, out, s, n_embd);
+            llama_batch_clear(batch);
+            p += s;
+            s = 0;
+        }
 
-    // l2-normalize embeddings
-    float norm = 0;
-    for (int i = 0; i < n_embd; i++) {
-        norm += embeddings[i] * embeddings[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n_embd; i++) {
-        embeddings[i] /= norm;
+        // add to batch
+        batch_add_seq(batch, inp, s);
+        s += 1;
     }
 
-    for (int i = 0; i < n_embd; i++) {
-        printf("%f ", embeddings[i]);
+    // final batch
+    float * out = emb + p * n_embd;
+    batch_decode(ctx, batch, out, s, n_embd);
+
+    // print first 3 embeddings
+    for (int j = 0; j < std::min(3, n_prompts); j++) {
+        fprintf(stderr, "embedding %d: ", j);
+        for (int i = 0; i < n_embd; i++) {
+            fprintf(stderr, "%f ", emb[j * n_embd + i]);
+        }
+        fprintf(stderr, "\n\n");
     }
-    printf("\n");
+    fprintf(stderr, "\n");
 
+    // clean up
     llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);
-
     llama_backend_free();
 
     return 0;
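Since batch_decode() writes L2-normalized rows into the output matrix, the cosine similarity between any two prompts reduces to a plain dot product of their rows. A minimal sketch of a helper (hypothetical, not part of this commit), assuming the emb / n_embd layout above:

    // cosine similarity of prompts i and j; rows of emb are unit-length,
    // so the raw dot product already equals the cosine
    static float embd_similarity_cos(const float * emb, int n_embd, int i, int j) {
        float dot = 0.0f;
        for (int d = 0; d < n_embd; d++) {
            dot += emb[i * n_embd + d] * emb[j * n_embd + d];
        }
        return dot;
    }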