Skip to content

Commit abf817a

Browse files
committed
cont : remove llama_sampling_context
ggml-ci
1 parent 164f9d7 commit abf817a

File tree

13 files changed

+146
-164
lines changed

13 files changed

+146
-164
lines changed

common/sampling.cpp

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ std::string gpt_sampling_params::print_samplers() const {
3131

3232
return result;
3333
}
34-
struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
35-
struct llama_sampling_context * result = new llama_sampling_context();
36-
37-
result->params = params;
34+
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
35+
struct llama_sampling * result = nullptr;
3836

3937
{
4038
auto lparams = llama_sampling_default_params();
@@ -66,35 +64,25 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
6664
lparams.samplers[i] = params.samplers[i];
6765
}
6866

69-
result->smpl = llama_sampling_init(model, lparams);
67+
result = llama_sampling_init(model, lparams);
7068

71-
llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
72-
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
69+
llama_sampling_set_grammar (result, params.grammar.c_str(), "root");
70+
llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data());
7371
}
7472

7573
return result;
7674
}
7775

78-
void llama_sampling_free(struct llama_sampling_context * ctx) {
79-
llama_sampling_free(ctx->smpl);
80-
81-
delete ctx;
82-
}
83-
84-
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
85-
if (dst->smpl) {
86-
llama_sampling_free(dst->smpl);
76+
void llama_sampling_cp(llama_sampling * src, llama_sampling * dst) {
77+
if (dst) {
78+
llama_sampling_free(dst);
8779
}
8880

89-
dst->smpl = llama_sampling_cp(src->smpl);
90-
}
91-
92-
llama_token llama_sampling_last(llama_sampling_context * ctx) {
93-
return llama_sampling_prev(ctx->smpl, 0);
81+
dst = llama_sampling_cp(src);
9482
}
9583

96-
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
97-
n = std::min(n, llama_sampling_n_prev(ctx_sampling->smpl));
84+
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) {
85+
n = std::min(n, llama_sampling_n_prev(smpl));
9886

9987
if (n <= 0) {
10088
return "";
@@ -104,7 +92,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
10492
result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
10593

10694
for (int i = n - 1; i >= 0; i--) {
107-
const llama_token id = llama_sampling_prev(ctx_sampling->smpl, i);
95+
const llama_token id = llama_sampling_prev(smpl, i);
10896

10997
GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
11098

@@ -206,14 +194,14 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
206194
}
207195

208196
llama_token llama_sampling_sample(
209-
struct llama_sampling_context * ctx_sampling,
210-
struct llama_context * ctx_main,
197+
struct llama_sampling * smpl,
198+
struct llama_context * ctx,
211199
int idx) {
212-
llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_main, idx));
200+
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
213201

214-
auto * cur_p = llama_sampling_get_candidates(ctx_sampling->smpl);
202+
auto * cur_p = llama_sampling_get_candidates(smpl);
215203

216-
llama_sampling_grammar(ctx_sampling->smpl, cur_p);
204+
llama_sampling_grammar(smpl, cur_p);
217205

218-
return llama_sampling_sample(ctx_sampling->smpl, cur_p);
206+
return llama_sampling_sample(smpl, cur_p);
219207
}

common/sampling.h

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
// sampling parameters
99
typedef struct gpt_sampling_params {
10-
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
10+
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling
1111

1212
int32_t n_prev = 64; // number of previous tokens to remember
1313
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
@@ -30,7 +30,7 @@ typedef struct gpt_sampling_params {
3030
bool penalize_nl = false; // consider newlines as a repeatable token
3131
bool ignore_eos = false;
3232

33-
std::vector<llama_sampler_type> samplers = {
33+
std::vector<enum llama_sampler_type> samplers = {
3434
LLAMA_SAMPLER_TYPE_TOP_K,
3535
LLAMA_SAMPLER_TYPE_TFS_Z,
3636
LLAMA_SAMPLER_TYPE_TYPICAL_P,
@@ -50,36 +50,21 @@ typedef struct gpt_sampling_params {
5050
std::string print_samplers() const;
5151
} gpt_sampling_params;
5252

53-
// general sampler context
54-
// TODO: move to llama.h
55-
struct llama_sampling_context {
56-
// parameters that will be used for sampling
57-
gpt_sampling_params params;
53+
// overload of llama_sampling_init using gpt_sampling_params
54+
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
5855

59-
llama_sampling * smpl;
60-
};
56+
void llama_sampling_cp(llama_sampling * src, llama_sampling * dst);
6157

62-
// Create a new sampling context instance.
63-
struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
58+
// get a string representation of the last accepted tokens
59+
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n);
6460

65-
void llama_sampling_free(struct llama_sampling_context * ctx);
61+
char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type);
62+
std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type);
6663

67-
// Copy the sampler context
68-
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
69-
70-
// Get the last accepted token
71-
llama_token llama_sampling_last(llama_sampling_context * ctx);
72-
73-
// Get a string representation of the last accepted tokens
74-
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
75-
76-
char llama_sampling_type_to_chr(llama_sampler_type sampler_type);
77-
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
78-
79-
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
80-
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
64+
std::vector<enum llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
65+
std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
8166

8267
llama_token llama_sampling_sample(
83-
struct llama_sampling_context * ctx_sampling,
84-
struct llama_context * ctx_main,
68+
struct llama_sampling * smpl,
69+
struct llama_context * ctx,
8570
int idx = -1);

examples/infill/infill.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
static llama_context ** g_ctx;
3535
static llama_model ** g_model;
36-
static llama_sampling_context ** g_ctx_sampling;
36+
static llama_sampling ** g_smpl;
3737
static gpt_params * g_params;
3838
static std::vector<llama_token> * g_input_tokens;
3939
static std::ostringstream * g_output_ss;
@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
9393
} else {
9494
console::cleanup();
9595
printf("\n");
96-
llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl);
96+
llama_print_timings(*g_ctx, *g_smpl);
9797
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
9898
_exit(130);
9999
}
@@ -167,11 +167,11 @@ int main(int argc, char ** argv) {
167167

168168
llama_model * model = nullptr;
169169
llama_context * ctx = nullptr;
170-
llama_sampling_context * ctx_sampling = nullptr;
170+
llama_sampling * smpl = nullptr;
171171

172172
g_model = &model;
173173
g_ctx = &ctx;
174-
g_ctx_sampling = &ctx_sampling;
174+
g_smpl = &smpl;
175175

176176
// load the model and apply lora adapter, if any
177177
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -345,7 +345,7 @@ int main(int argc, char ** argv) {
345345

346346
std::vector<llama_token> embd;
347347

348-
ctx_sampling = llama_sampling_init(model, sparams);
348+
smpl = llama_sampling_init(model, sparams);
349349

350350
while (n_remain != 0 || params.interactive) {
351351
// predict
@@ -417,11 +417,11 @@ int main(int argc, char ** argv) {
417417
embd.clear();
418418

419419
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
420-
const llama_token id = llama_sampling_sample(ctx_sampling, ctx);
420+
const llama_token id = llama_sampling_sample(smpl, ctx);
421421

422-
llama_sampling_accept(ctx_sampling->smpl, id, true);
422+
llama_sampling_accept(smpl, id, true);
423423

424-
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
424+
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
425425

426426
embd.push_back(id);
427427

@@ -440,7 +440,7 @@ int main(int argc, char ** argv) {
440440

441441
// push the prompt in the sampling context in order to apply repetition penalties later
442442
// for the prompt, we don't apply grammar rules
443-
llama_sampling_accept(ctx_sampling->smpl, embd_inp[n_consumed], false);
443+
llama_sampling_accept(smpl, embd_inp[n_consumed], false);
444444

445445
++n_consumed;
446446
if ((int) embd.size() >= params.n_batch) {
@@ -472,7 +472,7 @@ int main(int argc, char ** argv) {
472472
// if not currently processing queued inputs;
473473
if ((int) embd_inp.size() <= n_consumed) {
474474
// deal with eot token in infill mode
475-
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
475+
if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
476476
if (is_interacting && !params.interactive_first) {
477477
// print an eot token
478478
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@@ -538,7 +538,7 @@ int main(int argc, char ** argv) {
538538
is_interacting = false;
539539
}
540540
// deal with end of generation tokens in interactive mode
541-
else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
541+
else if (llama_token_is_eog(model, llama_sampling_last(smpl))) {
542542
LOG("found EOS token\n");
543543

544544
if (params.interactive) {
@@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
611611

612612
if (n_past > 0) {
613613
if (is_interacting) {
614-
llama_sampling_reset(ctx_sampling->smpl);
614+
llama_sampling_reset(smpl);
615615
}
616616
is_interacting = false;
617617
}
@@ -634,13 +634,13 @@ int main(int argc, char ** argv) {
634634
fflush(stdout);
635635
}
636636

637-
llama_print_timings(ctx, ctx_sampling->smpl);
637+
llama_print_timings(ctx, smpl);
638638
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
639639

640640
llama_free(ctx);
641641
llama_free_model(model);
642642

643-
llama_sampling_free(ctx_sampling);
643+
llama_sampling_free(smpl);
644644
llama_backend_free();
645645

646646
#ifndef LOG_DISABLE_LOGS

examples/llava/llava-cli.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
4040
return true;
4141
}
4242

43-
static const char * sample(struct llama_sampling_context * ctx_sampling,
43+
static const char * sample(struct llama_sampling * smpl,
4444
struct llama_context * ctx_llama,
4545
int * n_past) {
46-
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
47-
llama_sampling_accept(ctx_sampling->smpl, id, true);
46+
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
47+
llama_sampling_accept(smpl, id, true);
4848
static std::string ret;
4949
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
5050
ret = "</s>";
@@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
191191

192192
LOG_TEE("\n");
193193

194-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams);
195-
if (!ctx_sampling) {
194+
struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
195+
if (!smpl) {
196196
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
197197
exit(1);
198198
}
199199

200200
std::string response = "";
201201
for (int i = 0; i < max_tgt_len; i++) {
202-
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
202+
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
203203
response += tmp;
204204
if (strcmp(tmp, "</s>") == 0) break;
205205
if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
211211
fflush(stdout);
212212
}
213213

214-
llama_sampling_free(ctx_sampling);
214+
llama_sampling_free(smpl);
215215
printf("\n");
216216
}
217217

examples/llava/minicpmv-cli.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
163163
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
164164
}
165165

166-
static const char * sample(struct llama_sampling_context * ctx_sampling,
166+
static const char * sample(struct llama_sampling * smpl,
167167
struct llama_context * ctx_llama,
168168
int * n_past) {
169-
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
170-
llama_sampling_accept(ctx_sampling->smpl, id, true);
169+
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
170+
llama_sampling_accept(smpl, id, true);
171171
static std::string ret;
172172
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
173173
ret = "</s>";
@@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
214214
return ctx_llava;
215215
}
216216

217-
static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
217+
static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
218218
std::string user_prompt = prompt;
219219
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
220220
if (!is_first) {
@@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla
238238

239239
LOG_TEE("\n");
240240

241-
struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams);
242-
return ctx_sampling;
241+
struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
242+
return smpl;
243243
}
244244

245-
static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
245+
static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){
246246

247-
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
247+
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
248248
return tmp;
249249
}
250250

@@ -278,12 +278,12 @@ int main(int argc, char ** argv) {
278278
if (!params.prompt.empty()) {
279279
LOG_TEE("<user>%s\n", params.prompt.c_str());
280280
LOG_TEE("<assistant>");
281-
auto ctx_sampling = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
281+
auto smpl = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
282282
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
283283
std::string response = "";
284284
bool have_tmp = false;
285285
for (int i = 0; i < max_tgt_len; i++) {
286-
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
286+
auto tmp = llama_loop(ctx_llava, smpl, n_past);
287287
response += tmp;
288288
if (strcmp(tmp, "</s>") == 0){
289289
if(!have_tmp)continue;
@@ -296,26 +296,26 @@ int main(int argc, char ** argv) {
296296

297297
fflush(stdout);
298298
}
299-
llama_sampling_free(ctx_sampling);
299+
llama_sampling_free(smpl);
300300
}else {
301301
while (true) {
302302
LOG_TEE("<user>");
303303
std::string prompt;
304304
std::getline(std::cin, prompt);
305305
LOG_TEE("<assistant>");
306-
auto ctx_sampling = llama_init(ctx_llava, &params, prompt, n_past, true);
306+
auto smpl = llama_init(ctx_llava, &params, prompt, n_past, true);
307307
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
308308
std::string response = "";
309309
for (int i = 0; i < max_tgt_len; i++) {
310-
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
310+
auto tmp = llama_loop(ctx_llava, smpl, n_past);
311311
response += tmp;
312312
if (strcmp(tmp, "</s>") == 0) break;
313313
if (strstr(tmp, "###")) break; // Yi-VL behavior
314314
printf("%s", tmp);// mistral llava-1.6
315315
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
316316
fflush(stdout);
317317
}
318-
llama_sampling_free(ctx_sampling);
318+
llama_sampling_free(smpl);
319319
}
320320
}
321321
printf("\n");

0 commit comments

Comments
 (0)