Skip to content

Commit 5ae3426

Browse files
server: fix reported top tokens for temperature 0 (#7203)
1 parent b83cc3f commit 5ae3426

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

common/sampling.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
3535

3636
result->prev.resize(params.n_prev);
3737

38-
result->n_considered = 0;
38+
result->n_valid = 0;
3939

4040
llama_sampling_set_rng_seed(result, params.seed);
4141

@@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
6666

6767
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
6868
ctx->cur.clear();
69-
ctx->n_considered = 0;
69+
ctx->n_valid = 0;
7070
}
7171

7272
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl(
256256
}
257257
}
258258

259-
ctx_sampling->n_considered = cur_p.size;
259+
ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
260260

261261
return id;
262262
}

common/sampling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ struct llama_sampling_context {
8181
// TODO: replace with ring-buffer
8282
std::vector<llama_token> prev;
8383
std::vector<llama_token_data> cur;
84-
size_t n_considered;
84+
size_t n_valid; // Number of correct top tokens with correct probabilities.
8585

8686
std::mt19937 rng;
8787
};

examples/server/server.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,10 +2270,10 @@ struct server_context {
22702270

22712271
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
22722272
if (n_probs > 0) {
2273-
const size_t n_considered = slot.ctx_sampling->n_considered;
2273+
const size_t n_valid = slot.ctx_sampling->n_valid;
22742274

22752275
// Make sure at least n_probs top tokens are at the front of the vector:
2276-
if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
2276+
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
22772277
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
22782278
}
22792279

@@ -2289,7 +2289,7 @@ struct server_context {
22892289
for (size_t i = 0; i < n_probs; ++i) {
22902290
result.probs.push_back({
22912291
cur_p.data[i].id,
2292-
i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
2292+
i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
22932293
});
22942294
}
22952295
}

0 commit comments

Comments
 (0)