Commit 007064f

llama : context
ggml-ci
1 parent 5bf9dc5 commit 007064f

File tree

3 files changed (+383 -358 lines)
src/llama-context.cpp

Lines changed: 360 additions & 1 deletion
@@ -1,8 +1,367 @@
 #include "llama-context.h"

+#include <cassert>
 #include <cstring>
 #include <stdexcept>

+void llama_set_k_shift(struct llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].delta;
+    }
+}
+
+void llama_set_s_copy(struct llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
+// llama output
+
+size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+
+    const auto n_batch = cparams.n_batch;
+    const auto n_vocab = hparams.n_vocab;
+    const auto n_embd  = hparams.n_embd;
+
+    // TODO: use a per-batch flag for logits presence instead
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+
+    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+    const size_t embd_size   = has_embd   ? n_embd*n_outputs_max  : 0;
+
+    if (lctx.output_ids.empty()) {
+        // init, never resized afterwards
+        lctx.output_ids.resize(n_batch);
+    }
+
+    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
+    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+
+    // alloc only when more than the current capacity is required
+    // TODO: also consider shrinking the buffer
+    if (!lctx.buf_output || prev_size < new_size) {
+        if (lctx.buf_output) {
+#ifndef NDEBUG
+            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
+            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+            lctx.buf_output = nullptr;
+            lctx.logits = nullptr;
+            lctx.embd = nullptr;
+        }
+
+        auto * buft = ggml_backend_cpu_buffer_type();
+        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
+        auto * output_dev = lctx.model.dev_output.dev;
+        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
+        if (output_dev_host_buft) {
+            buft = output_dev_host_buft;
+        }
+        lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
+        if (lctx.buf_output == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+            return 0;
+        }
+    }
+
+    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
+
+    lctx.logits = has_logits ? output_base               : nullptr;
+    lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
+
+    lctx.output_size = n_outputs_max;
+    lctx.logits_size = logits_size;
+    lctx.embd_size   = embd_size;
+
+    // set all ids as invalid (negative)
+    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+
+    ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
+
+    lctx.n_outputs = 0;
+
+    return n_outputs_max;
+}
+
+void llama_output_reorder(struct llama_context & ctx) {
+    std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
+    if (!out_ids.empty()) {
+        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+        const uint32_t n_embd  = ctx.model.hparams.n_embd;
+
+        const int32_t n_outputs = ctx.n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx.logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx.embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx.output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
+
+//
+// interface implementation
+//
+
+void llama_free(struct llama_context * ctx) {
+    delete ctx;
+}
+
+uint32_t llama_n_ctx(const struct llama_context * ctx) {
+    return ctx->cparams.n_ctx;
+}
+
+uint32_t llama_n_batch(const struct llama_context * ctx) {
+    return ctx->cparams.n_batch;
+}
+
+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
+uint32_t llama_n_seq_max(const struct llama_context * ctx) {
+    return ctx->kv_self.size;
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
+void llama_attach_threadpool(
+        struct llama_context * ctx,
+        ggml_threadpool_t      threadpool,
+        ggml_threadpool_t      threadpool_batch) {
+    ctx->threadpool       = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
+}
+
+void llama_detach_threadpool(struct llama_context * ctx) {
+    ctx->threadpool       = nullptr;
+    ctx->threadpool_batch = nullptr;
+}
+
+void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
+    ctx->cparams.n_threads       = n_threads;
+    ctx->cparams.n_threads_batch = n_threads_batch;
+}
+
+int32_t llama_n_threads(struct llama_context * ctx) {
+    return ctx->cparams.n_threads;
+}
+
+int32_t llama_n_threads_batch(struct llama_context * ctx) {
+    return ctx->cparams.n_threads_batch;
+}
+
+void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+
+    for (auto & backend : ctx->backends) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
+        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+        if (set_abort_callback_fn) {
+            set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
+        }
+    }
+}
+
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
+}
+
+void llama_synchronize(struct llama_context * ctx) {
+    ggml_backend_sched_synchronize(ctx->sched.get());
+
+    // FIXME: if multiple single tokens are evaluated without a synchronization,
+    //        the stats will be added to the prompt evaluation stats
+    //        this should only happen when using batch size 1 to evaluate a batch
+
+    // add the evaluation to the stats
+    if (ctx->n_queued_tokens == 1) {
+        if (!ctx->cparams.no_perf) {
+            ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        }
+        ctx->n_eval++;
+    } else if (ctx->n_queued_tokens > 1) {
+        if (!ctx->cparams.no_perf) {
+            ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        }
+        ctx->n_p_eval += ctx->n_queued_tokens;
+    }
+
+    // get a more accurate load time, upon first eval
+    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
+        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
+        ctx->has_evaluated_once = true;
+    }
+
+    ctx->n_queued_tokens = 0;
+    ctx->t_compute_start_us = 0;
+}
+
+float * llama_get_logits(struct llama_context * ctx) {
+    llama_synchronize(ctx);
+
+    // reorder logits for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(*ctx);
+
+    return ctx->logits;
+}
+
+float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+
+    llama_synchronize(ctx);
+
+    try {
+        if (ctx->logits == nullptr) {
+            throw std::runtime_error("no logits");
+        }
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
+        }
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if (j >= ctx->n_outputs) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+        }
+
+        return ctx->logits + j*ctx->model.hparams.n_vocab;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#else
+        return nullptr;
+#endif
+    }
+}
+
+float * llama_get_embeddings(struct llama_context * ctx) {
+    llama_synchronize(ctx);
+
+    // reorder embeddings for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(*ctx);
+
+    return ctx->embd;
+}
+
+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+
+    llama_synchronize(ctx);
+
+    try {
+        if (ctx->embd == nullptr) {
+            throw std::runtime_error("no embeddings");
+        }
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
+        }
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if (j >= ctx->n_outputs) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+        }
+
+        return ctx->embd + j*ctx->model.hparams.n_embd;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#else
+        return nullptr;
+#endif
+    }
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
+
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
+// llama state API
+
 // deprecated
 size_t llama_get_state_size(struct llama_context * ctx) {
     return llama_state_get_size(ctx);
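
Aside (not part of the diff): the reorder logic in llama_output_reorder above can be exercised in isolation. Below is a minimal, self-contained C++ sketch of the same selection-sort permutation, swapping whole rows alongside their keys the way the commit swaps logits/embedding rows alongside out_ids. The names reorder_rows, rows, and n_cols are illustrative only, not llama.cpp API.

// Standalone illustration of the selection-sort reorder used above.
// reorder_rows/rows/n_cols are made-up names for this sketch.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Reorder rows (n_cols floats each) into ascending out_ids order using
// selection sort, mirroring the row swaps in llama_output_reorder.
static void reorder_rows(std::vector<size_t> & out_ids, std::vector<float> & rows, size_t n_cols) {
    const int32_t n_outputs = (int32_t) out_ids.size();

    for (int32_t i = 0; i < n_outputs - 1; ++i) {
        int32_t j_min = i;
        for (int32_t j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) { continue; }
        std::swap(out_ids[i], out_ids[j_min]);
        // swap the whole row, just like the logits/embd rows in the diff
        for (size_t k = 0; k < n_cols; ++k) {
            std::swap(rows[i*n_cols + k], rows[j_min*n_cols + k]);
        }
    }
}

int main() {
    std::vector<size_t> out_ids = { 2, 0, 1 };                      // batch positions, out of order
    std::vector<float>  rows    = { 2.f, 2.f, 0.f, 0.f, 1.f, 1.f }; // one 2-column row per output
    reorder_rows(out_ids, rows, 2);
    printf("%zu %zu %zu | %.0f %.0f\n", out_ids[0], out_ids[1], out_ids[2], rows[0], rows[2]);
    // prints: 0 1 2 | 0 1
    return 0;
}

Selection sort is a reasonable fit here because, as the commit's comment notes, it performs at most n_outputs - 1 swaps, and each swap moves an entire n_vocab- or n_embd-sized row.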
@@ -58,7 +417,7 @@ struct llama_data_write {
     //}

     void write_output_ids(struct llama_context * ctx) {
-        llama_output_reorder(ctx);
+        llama_output_reorder(*ctx);

         const uint32_t n_outputs = ctx->n_outputs;

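Aside (not part of the diff): the index handling shared by llama_get_logits_ith and llama_get_embeddings_ith reduces to two cases: a negative i counts back from the last output (i = -1 is the most recent output row), while a non-negative i is translated through output_ids, which holds -1 for batch positions whose logits flag was not set. A minimal sketch of just that mapping follows; resolve_output_row is a hypothetical helper for illustration, not llama.cpp API.

// Minimal sketch of the index resolution used by the *_ith accessors above.
// resolve_output_row() is a made-up helper for illustration only.
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns the output row for batch token index i, or -1 if the request is invalid.
static int32_t resolve_output_row(const std::vector<int32_t> & output_ids, int32_t n_outputs, int32_t i) {
    int32_t j = -1;

    if (i < 0) {
        j = n_outputs + i;                 // e.g. i == -1 -> last output row
    } else if ((size_t) i < output_ids.size()) {
        j = output_ids[i];                 // -1 if no logits were requested for this token
    }

    if (j < 0 || j >= n_outputs) {
        return -1;                         // out of range, or token produced no output
    }
    return j;                              // caller would read logits + j*n_vocab
}

The real accessors then return ctx->logits + j*n_vocab (or ctx->embd + j*n_embd) for the resolved row j, and log an error or return nullptr when the mapping fails.
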
0 commit comments