Commit 05ce8ce

sampling : remove n_valid from the state

ggml-ci

1 parent 650adf1 commit 05ce8ce

File tree

5 files changed: +39 -64 lines changed

common/common.cpp
Lines changed: 6 additions & 6 deletions

@@ -249,16 +249,14 @@ void gpt_params_handle_model_default(gpt_params & params) {
 }
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    auto & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -275,6 +273,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_hf_token(params);
 
+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
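
The net effect of the first hunk is that every piece of per-argument state is now declared inside the loop, and the sparams alias is only taken after parsing has finished. A minimal sketch of the resulting loop shape, with a hypothetical handle_arg standing in for gpt_params_find_arg (illustration only, not the real helper):

#include <algorithm>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for gpt_params_find_arg, for illustration only.
static bool handle_arg(const std::string & arg, bool & invalid_param) {
    invalid_param = false;
    return arg == "--verbose"; // pretend a single known flag
}

static void parse_args(int argc, char ** argv) {
    for (int i = 1; i < argc; i++) {
        const std::string arg_prefix = "--";

        // per-argument state now lives inside the loop body
        std::string arg = argv[i];
        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
            std::replace(arg.begin(), arg.end(), '_', '-');
        }

        // reset for every argument by construction, instead of being a
        // function-scope flag that earlier iterations could leave stale
        bool invalid_param = false;
        if (!handle_arg(arg, invalid_param)) {
            throw std::invalid_argument("error: unknown argument: " + arg);
        }
    }
}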

common/sampling.cpp
Lines changed: 4 additions & 8 deletions

@@ -40,8 +40,6 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
         llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
     }
 
-    result->n_valid = 0;
-
     return result;
 }
 
@@ -55,7 +53,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     llama_sampling_reset(ctx->smpl);
 
     ctx->cur.clear();
-    ctx->n_valid = 0;
+    ctx->org.clear();
 }
 
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -294,11 +292,11 @@ static llama_token llama_sampling_sample(
 
     llama_token id = 0;
 
-    if (temp < 0.0) {
+    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
         // greedy sampling, with probs
         llama_sampling_softmax(smpl, cur_p);
         id = cur_p->data[0].id;
-    } else if (temp == 0.0) {
+    } else if (temp == 0.0f) {
         // greedy sampling, no probs
         id = llama_sampling_sample_greedy(smpl, cur_p);
     } else {
@@ -325,8 +323,6 @@ static llama_token llama_sampling_sample(
         }
     }
 
-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p->size;
-
    return id;
 }
 
@@ -341,7 +337,7 @@
         return llama_sampling_sample(ctx_sampling, &cur_p);
     }
 
-    // TODO: this lofic is confusing, try to figure out a better way to handle this
+    // TODO: this logic is confusing, try to figure out a better way to handle this
 
     // store the original candidates
     ctx_sampling->org = ctx_sampling->cur;
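
The key behavioral change is in the third hunk: requesting probabilities (params.n_probs > 0) at temperature 0 now routes through the softmax branch, so the candidate list itself carries normalized probabilities and no separate n_valid counter is needed. A self-contained sketch of that dispatch, with toy stand-ins (softmax, pick_greedy, pick_dist) for the llama_sampling_* calls:

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <vector>

struct token_data { int id; float logit; float p; };

// Toy stand-in for llama_sampling_softmax: sort by logit, fill normalized probs.
static void softmax(std::vector<token_data> & cur) {
    std::sort(cur.begin(), cur.end(),
            [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    float sum = 0.0f;
    for (auto & t : cur) { t.p = std::exp(t.logit - cur[0].logit); sum += t.p; }
    for (auto & t : cur) { t.p /= sum; }
}

// Toy stand-in for llama_sampling_sample_greedy: argmax over the logits.
static int pick_greedy(const std::vector<token_data> & cur) {
    return std::max_element(cur.begin(), cur.end(),
            [](const token_data & a, const token_data & b) { return a.logit < b.logit; })->id;
}

// Toy stand-in for the stochastic path: draw from the softmaxed distribution.
static int pick_dist(std::vector<token_data> & cur) {
    softmax(cur);
    float r = (float) std::rand() / (float) RAND_MAX;
    for (const auto & t : cur) { if ((r -= t.p) <= 0.0f) { return t.id; } }
    return cur.back().id;
}

// Mirrors the new branch structure in llama_sampling_sample.
static int sample(std::vector<token_data> & cur, float temp, int n_probs) {
    if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) {
        // greedy sampling, with probs: normalize first so cur[i].p is usable
        softmax(cur);
        return cur[0].id;
    } else if (temp == 0.0f) {
        // greedy sampling, no probs: skip the softmax entirely
        return pick_greedy(cur);
    }
    // temperature sampling
    return pick_dist(cur);
}

Dropping the counter trades a small amount of extra softmax work in the greedy-with-probs path for one less piece of mutable state that callers had to keep in sync.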

common/sampling.h
Lines changed: 21 additions & 23 deletions

@@ -17,27 +17,27 @@ enum class llama_sampler_type : char {
 
 // sampling parameters
 typedef struct gpt_sampling_params {
-    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
-    int32_t n_prev = 64; // number of previous tokens to remember
-    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t top_k = 40; // <= 0 to use vocab size
-    float top_p = 0.95f; // 1.0 = disabled
-    float min_p = 0.05f; // 0.0 = disabled
-    float tfs_z = 1.00f; // 1.0 = disabled
-    float typical_p = 1.00f; // 1.0 = disabled
-    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float dynatemp_range = 0.00f; // 0.0 = disabled
-    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float penalty_repeat = 1.00f; // 1.0 = disabled
-    float penalty_freq = 0.00f; // 0.0 = disabled
-    float penalty_present = 0.00f; // 0.0 = disabled
-    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float mirostat_tau = 5.00f; // target entropy
-    float mirostat_eta = 0.10f; // learning rate
-    bool penalize_nl = false; // consider newlines as a repeatable token
-    bool ignore_eos = false;
+    uint32_t seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    int32_t  n_prev            = 64;    // number of previous tokens to remember
+    int32_t  n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t  min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t  top_k             = 40;    // <= 0 to use vocab size
+    float    top_p             = 0.95f; // 1.0 = disabled
+    float    min_p             = 0.05f; // 0.0 = disabled
+    float    tfs_z             = 1.00f; // 1.0 = disabled
+    float    typical_p         = 1.00f; // 1.0 = disabled
+    float    temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float    dynatemp_range    = 0.00f; // 0.0 = disabled
+    float    dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t  penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float    penalty_repeat    = 1.00f; // 1.0 = disabled
+    float    penalty_freq      = 0.00f; // 0.0 = disabled
+    float    penalty_present   = 0.00f; // 0.0 = disabled
+    int32_t  mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float    mirostat_tau      = 5.00f; // target entropy
+    float    mirostat_eta      = 0.10f; // learning rate
+    bool     penalize_nl       = false; // consider newlines as a repeatable token
+    bool     ignore_eos        = false;
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -68,8 +68,6 @@ struct llama_sampling_context {
 
     std::vector<llama_token_data> cur;
     std::vector<llama_token_data> org;
-
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
 // Create a new sampling context instance.
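
What remains in the context after this commit is just the candidate buffers; anything previously derived from n_valid has to be read off cur directly. A paraphrased sketch of the slimmed struct (members unrelated to this change, such as the smpl handle, are omitted):

#include <vector>

#include "llama.h" // for llama_token_data

// Paraphrased shape of llama_sampling_context after this commit.
struct llama_sampling_context_slim {
    std::vector<llama_token_data> cur; // current candidates (probs filled when requested)
    std::vector<llama_token_data> org; // original candidates, saved before samplers run
    // size_t n_valid;                 // removed: no longer tracked as separate state
};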

examples/server/server.cpp
Lines changed: 7 additions & 27 deletions

@@ -2352,35 +2352,15 @@ struct server_context {
                 metrics.on_prompt_eval(slot);
             }
 
-            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
 
-            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-            if (n_probs > 0) {
-                const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                // Make sure at least n_probs top tokens are at the front of the vector:
-                // TODO: decide to how to handle this after the refactoring
-                //if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                //    llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
-                //}
-
-                if (slot.sparams.temp == 0.0f) {
-                    // With greedy sampling the probabilities have possibly not been calculated.
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id,
-                            i == 0 ? 1.0f : 0.0f
-                        });
-                    }
-                } else {
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id,
-                            i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                        });
-                    }
-                }
+            const llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
+
+            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                result.probs.push_back({
+                    cur_p.data[i].id,
+                    i >= cur_p.size ? 0.0f : cur_p.data[i].p,
+                });
             }
 
             if (!process_token(result, slot)) {
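
With the sampling code now guaranteeing softmaxed candidates in cur whenever probabilities are requested, the server no longer branches on temperature. A standalone sketch of the simplified export with toy types; the bounds guard on the id lookup is an extra safety this sketch adds for the case n_probs > cur.size():

#include <cstddef>
#include <vector>

struct token_data { int id; float logit; float p; };
struct token_prob { int tok; float prob; };

// Toy version of the simplified probability export: entries within the
// candidate list report their real probability, entries past the end are
// reported with probability 0 (e.g. tokens filtered out by top_k).
static std::vector<token_prob> export_probs(const std::vector<token_data> & cur, size_t n_probs) {
    std::vector<token_prob> out;
    out.reserve(n_probs);
    for (size_t i = 0; i < n_probs; ++i) {
        const bool in_range = i < cur.size();
        out.push_back({
            in_range ? cur[i].id : -1,  // placeholder id once candidates run out (sketch-only guard)
            in_range ? cur[i].p : 0.0f,
        });
    }
    return out;
}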

include/llama.h
Lines changed: 1 addition & 0 deletions

@@ -210,6 +210,7 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
        size_t size;
         bool sorted;
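
The new comment points at a possible structure-of-arrays layout for the candidate list. Purely as an illustration of what that TODO could mean (hypothetical, not part of this commit), an SoA variant next to the current AoS layout:

#include <cstddef>
#include <cstdint>

typedef int32_t llama_token;

// Current array-of-structs (AoS) layout: one struct per candidate token.
typedef struct llama_token_data {
    llama_token id;
    float logit;
    float p;
} llama_token_data;

typedef struct llama_token_data_array {
    llama_token_data * data;
    size_t size;
    bool sorted;
} llama_token_data_array;

// Hypothetical structure-of-arrays (SoA) variant: one parallel array per
// field, which can be friendlier to SIMD when a sampler scans only the
// logits (or only the probabilities) of many candidates.
typedef struct llama_token_data_array_soa {
    llama_token * ids;
    float * logits;
    float * probs;
    size_t size;
    bool sorted;
} llama_token_data_array_soa;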
