Skip to content

Commit b91937c

Browse files
committed
sampling : option to use internal set of candidates
ggml-ci
1 parent abf817a commit b91937c

File tree

8 files changed

+82
-25
lines changed

8 files changed

+82
-25
lines changed

common/sampling.cpp

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -199,9 +199,25 @@ llama_token llama_sampling_sample(
199199
int idx) {
200200
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
201201

202-
auto * cur_p = llama_sampling_get_candidates(smpl);
202+
// first, sample the token without any grammar constraints
203+
auto id = llama_sampling_sample(smpl, nullptr);
203204

204-
llama_sampling_grammar(smpl, cur_p);
205+
// create an array with a single token data element for the sampled id
206+
llama_token_data single_token_data = {id, 1.0f, 0.0f};
207+
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
205208

206-
return llama_sampling_sample(smpl, cur_p);
209+
llama_sampling_grammar(smpl, &single_token_data_array);
210+
211+
// check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
212+
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
213+
if (is_valid) {
214+
return id;
215+
}
216+
217+
// if the token is not valid, sample again, after applying the grammar constraints
218+
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
219+
220+
llama_sampling_grammar(smpl, nullptr);
221+
222+
return llama_sampling_sample(smpl, nullptr);
207223
}

common/sampling.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,4 +67,4 @@ std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::
6767
llama_token llama_sampling_sample(
6868
struct llama_sampling * smpl,
6969
struct llama_context * ctx,
70-
int idx = -1);
70+
int idx);

examples/batched.swift/Sources/main.swift

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -136,28 +136,17 @@ while n_cur <= n_len {
136136
continue
137137
}
138138

139-
var n_vocab = llama_n_vocab(model)
140139
var logits = llama_get_logits_ith(context, i_batch[i])
141140

142-
var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
141+
llama_sampling_set_logits(smpl, logits)
143142

144-
for token_id in 0 ..< n_vocab {
145-
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
146-
}
147-
148-
var candidates_p: llama_token_data_array = .init(
149-
data: &candidates,
150-
size: candidates.count,
151-
sorted: false
152-
)
153-
154-
llama_sampling_top_k(smpl, &candidates_p)
155-
llama_sampling_top_p(smpl, &candidates_p)
156-
llama_sampling_temp (smpl, &candidates_p)
143+
llama_sampling_top_k(smpl, nil)
144+
llama_sampling_top_p(smpl, nil)
145+
llama_sampling_temp (smpl, nil)
157146

158-
let new_token_id = llama_sampling_sample_dist(smpl, &candidates_p)
147+
let new_token_id = llama_sampling_sample_dist(smpl, nil)
159148

160-
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
149+
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);
161150

162151
// is it an end of stream? -> mark the stream as finished
163152
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,7 @@ int main(int argc, char ** argv) {
417417
embd.clear();
418418

419419
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
420-
const llama_token id = llama_sampling_sample(smpl, ctx);
420+
const llama_token id = llama_sampling_sample(smpl, ctx, -1);
421421

422422
llama_sampling_accept(smpl, id, true);
423423

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
4343
static const char * sample(struct llama_sampling * smpl,
4444
struct llama_context * ctx_llama,
4545
int * n_past) {
46-
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
46+
const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
4747
llama_sampling_accept(smpl, id, true);
4848
static std::string ret;
4949
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
166166
static const char * sample(struct llama_sampling * smpl,
167167
struct llama_context * ctx_llama,
168168
int * n_past) {
169-
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
169+
const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
170170
llama_sampling_accept(smpl, id, true);
171171
static std::string ret;
172172
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -684,7 +684,7 @@ int main(int argc, char ** argv) {
684684
LOG("saved session to %s\n", path_session.c_str());
685685
}
686686

687-
const llama_token id = llama_sampling_sample(smpl, ctx);
687+
const llama_token id = llama_sampling_sample(smpl, ctx, -1);
688688

689689
llama_sampling_accept(smpl, id, /* apply_grammar= */ true);
690690

src/llama.cpp

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20161,42 +20161,70 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * s
2016120161
void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2016220162
time_meas tm(smpl->t_sample_us);
2016320163

20164+
if (candidates == nullptr) {
20165+
candidates = &smpl->cur_p;
20166+
}
20167+
2016420168
llama_sampling_softmax_impl(candidates);
2016520169
}
2016620170

2016720171
void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2016820172
time_meas tm(smpl->t_sample_us);
2016920173

20174+
if (candidates == nullptr) {
20175+
candidates = &smpl->cur_p;
20176+
}
20177+
2017020178
llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
2017120179
}
2017220180

2017320181
void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2017420182
time_meas tm(smpl->t_sample_us);
2017520183

20184+
if (candidates == nullptr) {
20185+
candidates = &smpl->cur_p;
20186+
}
20187+
2017620188
llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
2017720189
}
2017820190

2017920191
void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2018020192
time_meas tm(smpl->t_sample_us);
2018120193

20194+
if (candidates == nullptr) {
20195+
candidates = &smpl->cur_p;
20196+
}
20197+
2018220198
llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
2018320199
}
2018420200

2018520201
void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2018620202
time_meas tm(smpl->t_sample_us);
2018720203

20204+
if (candidates == nullptr) {
20205+
candidates = &smpl->cur_p;
20206+
}
20207+
2018820208
llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
2018920209
}
2019020210

2019120211
void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2019220212
time_meas tm(smpl->t_sample_us);
2019320213

20214+
if (candidates == nullptr) {
20215+
candidates = &smpl->cur_p;
20216+
}
20217+
2019420218
llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
2019520219
}
2019620220

2019720221
void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2019820222
time_meas tm(smpl->t_sample_us);
2019920223

20224+
if (candidates == nullptr) {
20225+
candidates = &smpl->cur_p;
20226+
}
20227+
2020020228
if (smpl->params.dynatemp_range > 0) {
2020120229
const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
2020220230
const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);
@@ -20210,6 +20238,10 @@ void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array *
2021020238
void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2021120239
time_meas tm(smpl->t_grammar_us);
2021220240

20241+
if (candidates == nullptr) {
20242+
candidates = &smpl->cur_p;
20243+
}
20244+
2021320245
if (smpl->grammar) {
2021420246
llama_sampling_grammar_impl(candidates, *smpl->grammar);
2021520247

@@ -20222,6 +20254,10 @@ void llama_sampling_penalties(
2022220254
llama_token_data_array * candidates) {
2022320255
time_meas tm(smpl->t_sample_us);
2022420256

20257+
if (candidates == nullptr) {
20258+
candidates = &smpl->cur_p;
20259+
}
20260+
2022520261
const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());
2022620262

2022720263
const float penalty_repeat = smpl->params.penalty_repeat;
@@ -20246,6 +20282,10 @@ void llama_sampling_penalties(
2024620282
llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2024720283
time_meas tm(smpl->t_sample_us);
2024820284

20285+
if (candidates == nullptr) {
20286+
candidates = &smpl->cur_p;
20287+
}
20288+
2024920289
const auto type = smpl->params.mirostat;
2025020290

2025120291
llama_token res;
@@ -20276,6 +20316,10 @@ llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_t
2027620316
llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2027720317
time_meas tm(smpl->t_sample_us);
2027820318

20319+
if (candidates == nullptr) {
20320+
candidates = &smpl->cur_p;
20321+
}
20322+
2027920323
auto res = llama_sampling_sample_greedy_impl(candidates);
2028020324

2028120325
smpl->n_sample++;
@@ -20286,6 +20330,10 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok
2028620330
llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2028720331
time_meas tm(smpl->t_sample_us);
2028820332

20333+
if (candidates == nullptr) {
20334+
candidates = &smpl->cur_p;
20335+
}
20336+
2028920337
auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
2029020338

2029120339
smpl->n_sample++;
@@ -20296,6 +20344,10 @@ llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token
2029620344
llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2029720345
time_meas tm(smpl->t_sample_us);
2029820346

20347+
if (candidates == nullptr) {
20348+
candidates = &smpl->cur_p;
20349+
}
20350+
2029920351
const auto & params = smpl->params;
2030020352

2030120353
const float temp = params.temp;

0 commit comments

Comments (0)