Skip to content

Commit 123eaf0

Browse files
sampling: separate rng per sampling context
1 parent b1a1891 commit 123eaf0

File tree

8 files changed

+34
-10
lines changed

8 files changed

+34
-10
lines changed

common/sampling.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
#define LLAMA_API_INTERNAL
12
#include "sampling.h"
3+
#include <random>
24

35
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
46
struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
3335

3436
result->prev.resize(params.n_prev);
3537

38+
llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
39+
3640
return result;
3741
}
3842

@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
6266
ctx->cur.clear();
6367
}
6468

69+
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
70+
if (seed == LLAMA_DEFAULT_SEED) {
71+
seed = time(NULL);
72+
}
73+
ctx->rng.seed(seed);
74+
}
75+
6576
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
6677
if (dst->grammar) {
6778
llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
203214

204215
sampler_queue(ctx_main, params, cur_p, min_keep);
205216

206-
id = llama_sample_token(ctx_main, &cur_p);
217+
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
207218

208219
//{
209220
// const int n_top = 10;

common/sampling.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
#include "grammar-parser.h"
66

7+
#include <random>
78
#include <string>
8-
#include <vector>
99
#include <unordered_map>
10+
#include <vector>
1011

1112
// sampler types
1213
enum class llama_sampler_type : char {
@@ -79,6 +80,8 @@ struct llama_sampling_context {
7980
// TODO: replace with ring-buffer
8081
std::vector<llama_token> prev;
8182
std::vector<llama_token_data> cur;
83+
84+
std::mt19937 rng;
8285
};
8386

8487
#include "common.h"
@@ -93,6 +96,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
9396
// - reset grammar
9497
void llama_sampling_reset(llama_sampling_context * ctx);
9598

99+
// Set the sampler seed
100+
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
101+
96102
// Copy the sampler context
97103
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
98104

examples/lookup/lookup-stats.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
3030

3131
// load the model
3232
std::tie(model, ctx) = llama_init_from_gpt_params(params);
33-
llama_set_rng_seed(ctx, params.seed);
3433
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
3534

3635
// tokenize the prompt

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
3838

3939
// load the model
4040
std::tie(model, ctx) = llama_init_from_gpt_params(params);
41-
llama_set_rng_seed(ctx, params.seed);
4241
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
4342

4443
// tokenize the prompt
@@ -108,6 +107,7 @@ int main(int argc, char ** argv){
108107
bool has_eos = false;
109108

110109
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
110+
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
111111

112112
std::vector<llama_token> draft;
113113

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
240240
return 1;
241241
}
242242
session_tokens.resize(n_token_count_out);
243-
llama_set_rng_seed(ctx, params.seed);
244243
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
245244
}
246245
}
@@ -521,6 +520,7 @@ int main(int argc, char ** argv) {
521520
}
522521

523522
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
523+
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
524524

525525
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
526526
// predict

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,7 @@ struct server_context {
10281028
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
10291029
return false;
10301030
}
1031-
llama_set_rng_seed(ctx, slot.params.seed);
1031+
llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
10321032
}
10331033

10341034
slot.command = SLOT_COMMAND_LOAD_PROMPT;

llama.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13478,7 +13478,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
1347813478
return result;
1347913479
}
1348013480

13481-
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
13481+
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
1348213482
GGML_ASSERT(ctx);
1348313483

1348413484
const int64_t t_start_sample_us = ggml_time_us();
@@ -13491,7 +13491,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
1349113491
}
1349213492

1349313493
std::discrete_distribution<> dist(probs.begin(), probs.end());
13494-
auto & rng = ctx->rng;
1349513494
int idx = dist(rng);
1349613495

1349713496
llama_token result = candidates->data[idx].id;
@@ -13501,6 +13500,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
1350113500
return result;
1350213501
}
1350313502

13503+
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
13504+
return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
13505+
}
13506+
1350413507
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
1350513508
const int64_t t_start_sample_us = ggml_time_us();
1350613509

llama.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,7 @@ extern "C" {
987987
struct llama_context * ctx,
988988
llama_token_data_array * candidates);
989989

990-
/// @details Randomly selects a token from the candidates based on their probabilities.
990+
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
991991
LLAMA_API llama_token llama_sample_token(
992992
struct llama_context * ctx,
993993
llama_token_data_array * candidates);
@@ -1074,8 +1074,9 @@ extern "C" {
10741074
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
10751075
#ifdef LLAMA_API_INTERNAL
10761076

1077-
#include <vector>
1077+
#include <random>
10781078
#include <string>
1079+
#include <vector>
10791080

10801081
struct ggml_tensor;
10811082

@@ -1112,6 +1113,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
11121113
const std::string & src,
11131114
llama_partial_utf8 partial_start);
11141115

1116+
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
1117+
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
1118+
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
1119+
11151120
#endif // LLAMA_API_INTERNAL
11161121

11171122
#endif // LLAMA_H

0 commit comments

Comments
 (0)