Commit 28103f4
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results
* sampling: separate rng per sampling context
1 parent c0d1b3e commit 28103f4
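
The gist of the fix: each llama_sampling_context now owns its own std::mt19937 seeded from sparams.seed, instead of every server slot drawing from the single RNG stored inside llama_context. A minimal sketch of the principle (standard C++ only, no llama.cpp API; names are illustrative):

    // Generators seeded identically stay in lockstep no matter how a third,
    // unrelated generator is advanced in between -- this is why per-slot RNGs
    // make same-seed completions reproducible under concurrency.
    #include <cassert>
    #include <random>

    int main() {
        std::mt19937 slot_a(42);   // stands in for slot A's sampling context
        std::mt19937 slot_b(42);   // slot B, same seed
        std::mt19937 other(12345); // some other slot running concurrently

        for (int i = 0; i < 8; ++i) {
            other();                      // interleaved draws elsewhere...
            assert(slot_a() == slot_b()); // ...never desynchronize A and B
        }
        return 0;
    }

With a single shared RNG, the interleaving of draws across slots would make the sequence seen by any one slot depend on what the other slots are doing, which is exactly the bug this commit addresses.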

11 files changed: +145 -30 lines changed


common/common.cpp

Lines changed: 2 additions & 0 deletions
@@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
+        // This is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {

common/sampling.cpp

Lines changed: 12 additions & 1 deletion
@@ -1,4 +1,6 @@
+#define LLAMA_API_INTERNAL
 #include "sampling.h"
+#include <random>

 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

     result->prev.resize(params.n_prev);

+    llama_sampling_set_rng_seed(result, params.seed);
+
     return result;
 }

@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }

+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(

     sampler_queue(ctx_main, params, cur_p, min_keep);

-    id = llama_sample_token(ctx_main, &cur_p);
+    id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);

     //{
     //    const int n_top = 10;
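
A hedged usage sketch of the seeding path added above, assuming it is compiled inside the llama.cpp tree against common/sampling.h from this commit (no model is loaded; only the RNG plumbing is exercised):

    // Sketch only: exercises llama_sampling_init / llama_sampling_set_rng_seed
    // without any decoding. One sampling context per server slot.
    #include "sampling.h"

    int main() {
        llama_sampling_params sparams;
        sparams.seed = 42; // same fixed seed for every slot

        // Each context gets its own std::mt19937, seeded inside
        // llama_sampling_init() via llama_sampling_set_rng_seed().
        llama_sampling_context * slot_a = llama_sampling_init(sparams);
        llama_sampling_context * slot_b = llama_sampling_init(sparams);

        // Re-seeding later (e.g. when a request supplies its own "seed")
        // only touches that slot's generator.
        llama_sampling_set_rng_seed(slot_b, 1337);

        llama_sampling_free(slot_a);
        llama_sampling_free(slot_b);
        return 0;
    }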

common/sampling.h

Lines changed: 27 additions & 20 deletions
@@ -4,9 +4,10 @@

 #include "grammar-parser.h"

+#include <random>
 #include <string>
-#include <vector>
 #include <unordered_map>
+#include <vector>

 // sampler types
 enum class llama_sampler_type : char {

In the hunk below, the original diff also re-aligned every existing parameter field (deleting and re-adding them unchanged); those fields are shown once as context here, and the functional change is the new seed field.

@@ -20,25 +21,26 @@ enum class llama_sampler_type : char {

 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64;                // number of previous tokens to remember
     int32_t n_probs = 0;                // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0;               // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t top_k = 40;                 // <= 0 to use vocab size
     float top_p = 0.95f;                // 1.0 = disabled
     float min_p = 0.05f;                // 0.0 = disabled
     float tfs_z = 1.00f;                // 1.0 = disabled
     float typical_p = 1.00f;            // 1.0 = disabled
     float temp = 0.80f;                 // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float dynatemp_range = 0.00f;       // 0.0 = disabled
     float dynatemp_exponent = 1.00f;    // controls how entropy maps to temperature in dynamic temperature sampler
     int32_t penalty_last_n = 64;        // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float penalty_repeat = 1.00f;       // 1.0 = disabled
     float penalty_freq = 0.00f;         // 0.0 = disabled
     float penalty_present = 0.00f;      // 0.0 = disabled
     int32_t mirostat = 0;               // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f;         // target entropy
     float mirostat_eta = 0.10f;         // learning rate
     bool penalize_nl = false;           // consider newlines as a repeatable token
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -79,6 +81,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };

 #include "common.h"
@@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);

+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);


examples/lookup/lookup-stats.cpp

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){

     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

     // tokenize the prompt

examples/lookup/lookup.cpp

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){

     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

     // tokenize the prompt

examples/main/main.cpp

Lines changed: 0 additions & 1 deletion
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
             LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
         }
     }

examples/server/server.cpp

Lines changed: 1 addition & 2 deletions
@@ -854,7 +854,7 @@ struct server_context {
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
         slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
-        slot.params.seed = json_value(data, "seed", default_params.seed);
+        slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

@@ -1028,7 +1028,6 @@ struct server_context {
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
-            llama_set_rng_seed(ctx, slot.params.seed);
         }

         slot.command = SLOT_COMMAND_LOAD_PROMPT;
examples/server/tests/features/results.feature (new file)

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+@llama.cpp
+@results
+Feature: Results
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
+    And a model file test-model-00001-of-00003.gguf
+    And 128 as batch size
+    And 256 KV cache size
+    And 128 max tokens to predict
+
+  Scenario Outline: Multi users completion
+    Given <n_slots> slots
+    And continuous batching
+    Then the server is starting
+    Then the server is healthy
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given concurrent completion requests
+    Then the server is busy
+    Then the server is idle
+    And all slots are idle
+    Then all predictions are equal
+    Examples:
+      | n_slots |
+      | 1       |
+      | 2       |

examples/server/tests/features/steps/steps.py

Lines changed: 34 additions & 0 deletions
@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_metrics = False
     context.server_process = None
     context.seed = None
+    context.draft = None
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl


+@step('{draft:d} as draft')
+def step_draft(context, draft):
+    context.draft = draft
+
+
 @step('{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
     context.n_ctx = n_ctx
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
     assert_n_tokens_predicted(context.completion, predicted_n)


+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_predictions_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
     assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                  f' {n_predicted} <> {expected_predicted_n}')

+def assert_all_predictions_equal(completion_responses):
+    content_0 = completion_responses[0]['content']
+
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        print(f"content 0: {content_0}")
+
+    i = 1
+    for response in completion_responses[1:]:
+        content = response['content']
+
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+            print(f"content {i}: {content}")
+
+        assert content == content_0, "contents not equal"
+
+        i += 1
+

 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
@@ -1148,6 +1180,8 @@ def start_server_background(context):
         server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+    if context.draft is not None:
+        server_args.extend(['--draft', context.draft])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:

llama.cpp

Lines changed: 5 additions & 2 deletions
@@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
     return result;
 }

-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
     GGML_ASSERT(ctx);

     const int64_t t_start_sample_us = ggml_time_us();
@@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     }

     std::discrete_distribution<> dist(probs.begin(), probs.end());
-    auto & rng = ctx->rng;
     int idx = dist(rng);

     llama_token result = candidates->data[idx].id;
@@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }

+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
+}
+
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();


llama.h

Lines changed: 7 additions & 2 deletions
@@ -987,7 +987,7 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);

-    /// @details Randomly selects a token from the candidates based on their probabilities.
+    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
@@ -1074,8 +1074,9 @@ extern "C" {
 // Internal API to be implemented by llama.cpp and used by tests/benchmarks only
 #ifdef LLAMA_API_INTERNAL

-#include <vector>
+#include <random>
 #include <string>
+#include <vector>

 struct ggml_tensor;

@@ -1112,6 +1113,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start);

+// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
+// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+
 #endif // LLAMA_API_INTERNAL

 #endif // LLAMA_H
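
For completeness, a hedged sketch of how the internal entry point declared above can be used. This assumes the file is compiled inside the llama.cpp tree; the wrapper function is hypothetical and the model/candidate setup is deliberately left to the caller:

    // LLAMA_API_INTERNAL must be defined before including llama.h so that
    // llama_sample_token_with_rng is visible, mirroring common/sampling.cpp.
    #define LLAMA_API_INTERNAL
    #include "llama.h"

    #include <random>

    // Illustrative helper: each caller (e.g. each server slot) passes its own
    // generator, so concurrent sampling no longer races on the single RNG
    // stored inside llama_context.
    static llama_token sample_with_local_rng(llama_context * ctx,
                                             llama_token_data_array * candidates,
                                             std::mt19937 & rng) {
        return llama_sample_token_with_rng(ctx, candidates, rng);
    }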
