|
 #include "llama-context.h"

+#include <cassert>
 #include <cstring>
 #include <stdexcept>

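+// write the per-cell position delta of the KV cache into the K-shift input tensor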
+void llama_set_k_shift(struct llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].delta;
+    }
+}
+
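+// write the per-cell src index of the KV cache into the s_copy input tensor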
+void llama_set_s_copy(struct llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
+// llama output
+
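+// make sure the output buffer can hold the logits/embeddings for n_outputs rows;
+// returns the number of output rows reserved (at least n_outputs), or 0 if the buffer allocation failed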
+size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+
+    const auto n_batch = cparams.n_batch;
+    const auto n_vocab = hparams.n_vocab;
+    const auto n_embd  = hparams.n_embd;
+
+    // TODO: use a per-batch flag for logits presence instead
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+
+    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+    const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+
+    if (lctx.output_ids.empty()) {
+        // init, never resized afterwards
+        lctx.output_ids.resize(n_batch);
+    }
+
+    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
+    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+
+    // alloc only when more than the current capacity is required
+    // TODO: also consider shrinking the buffer
+    if (!lctx.buf_output || prev_size < new_size) {
+        if (lctx.buf_output) {
+#ifndef NDEBUG
+            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
+            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+            lctx.buf_output = nullptr;
+            lctx.logits = nullptr;
+            lctx.embd = nullptr;
+        }
+
+        auto * buft = ggml_backend_cpu_buffer_type();
+        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
+        auto * output_dev = lctx.model.dev_output.dev;
+        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
+        if (output_dev_host_buft) {
+            buft = output_dev_host_buft;
+        }
+        lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
+        if (lctx.buf_output == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+            return 0;
+        }
+    }
+
+    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
+
+    lctx.logits = has_logits ? output_base               : nullptr;
+    lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
+
+    lctx.output_size = n_outputs_max;
+    lctx.logits_size = logits_size;
+    lctx.embd_size   = embd_size;
+
+    // set all ids as invalid (negative)
+    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+
+    ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
+
+    lctx.n_outputs = 0;
+
+    return n_outputs_max;
+}
+
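+// restore the outputs to the order in which they were requested in the batch (see sbatch.out_ids)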
+void llama_output_reorder(struct llama_context & ctx) {
+    std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
+    if (!out_ids.empty()) {
+        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+        const uint32_t n_embd  = ctx.model.hparams.n_embd;
+
+        const int32_t n_outputs = ctx.n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx.logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx.embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx.output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
+
+//
+// interface implementation
+//
+
+void llama_free(struct llama_context * ctx) {
+    delete ctx;
+}
+
+uint32_t llama_n_ctx(const struct llama_context * ctx) {
+    return ctx->cparams.n_ctx;
+}
+
+uint32_t llama_n_batch(const struct llama_context * ctx) {
+    return ctx->cparams.n_batch;
+}
+
+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
+uint32_t llama_n_seq_max(const struct llama_context * ctx) {
+    return ctx->kv_self.size;
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
+void llama_attach_threadpool(
+        struct llama_context * ctx,
+        ggml_threadpool_t      threadpool,
+        ggml_threadpool_t      threadpool_batch) {
+    ctx->threadpool       = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
+}
+
+void llama_detach_threadpool(struct llama_context * ctx) {
+    ctx->threadpool       = nullptr;
+    ctx->threadpool_batch = nullptr;
+}
+
+void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
+    ctx->cparams.n_threads       = n_threads;
+    ctx->cparams.n_threads_batch = n_threads_batch;
+}
+
+int32_t llama_n_threads(struct llama_context * ctx) {
+    return ctx->cparams.n_threads;
+}
+
+int32_t llama_n_threads_batch(struct llama_context * ctx) {
+    return ctx->cparams.n_threads_batch;
+}
+
+void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+
+    for (auto & backend : ctx->backends) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
+        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+        if (set_abort_callback_fn) {
+            set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
+        }
+    }
+}
+
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
+}
+
+void llama_synchronize(struct llama_context * ctx) {
+    ggml_backend_sched_synchronize(ctx->sched.get());
+
+    // FIXME: if multiple single tokens are evaluated without a synchronization,
+    //        the stats will be added to the prompt evaluation stats
+    //        this should only happen when using batch size 1 to evaluate a batch
+
+    // add the evaluation to the stats
+    if (ctx->n_queued_tokens == 1) {
+        if (!ctx->cparams.no_perf) {
+            ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        }
+        ctx->n_eval++;
+    } else if (ctx->n_queued_tokens > 1) {
+        if (!ctx->cparams.no_perf) {
+            ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        }
+        ctx->n_p_eval += ctx->n_queued_tokens;
+    }
+
+    // get a more accurate load time, upon first eval
+    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
+        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
+        ctx->has_evaluated_once = true;
+    }
+
+    ctx->n_queued_tokens = 0;
+    ctx->t_compute_start_us = 0;
+}
+
+float * llama_get_logits(struct llama_context * ctx) {
+    llama_synchronize(ctx);
+
+    // reorder logits for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(*ctx);
+
+    return ctx->logits;
+}
+
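+// return the logits of the i-th token of the last batch; a negative i counts back from the last output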
+float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+
+    llama_synchronize(ctx);
+
+    try {
+        if (ctx->logits == nullptr) {
+            throw std::runtime_error("no logits");
+        }
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
+        }
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if (j >= ctx->n_outputs) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+        }
+
+        return ctx->logits + j*ctx->model.hparams.n_vocab;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#else
+        return nullptr;
+#endif
+    }
+}
+
+float * llama_get_embeddings(struct llama_context * ctx) {
+    llama_synchronize(ctx);
+
+    // reorder embeddings for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(*ctx);
+
+    return ctx->embd;
+}
+
+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+
+    llama_synchronize(ctx);
+
+    try {
+        if (ctx->embd == nullptr) {
+            throw std::runtime_error("no embeddings");
+        }
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
+        }
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if (j >= ctx->n_outputs) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+        }
+
+        return ctx->embd + j*ctx->model.hparams.n_embd;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#else
+        return nullptr;
+#endif
+    }
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
+
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
+// llama state API
+
 // deprecated
 size_t llama_get_state_size(struct llama_context * ctx) {
     return llama_state_get_size(ctx);
@@ -58,7 +417,7 @@ struct llama_data_write {
     //}

     void write_output_ids(struct llama_context * ctx) {
-        llama_output_reorder(ctx);
+        llama_output_reorder(*ctx);

         const uint32_t n_outputs = ctx->n_outputs;

|
|