Skip to content

Commit 29f712e

Browse files
committed
llama : update sampling API
ggml-ci
1 parent 4d2873a commit 29f712e

File tree

15 files changed

+510
-584
lines changed

15 files changed

+510
-584
lines changed

common/sampling.cpp

Lines changed: 117 additions & 184 deletions
Large diffs are not rendered by default.

common/sampling.h

Lines changed: 15 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
#include <string>
66
#include <vector>
7-
#include <stdexcept>
87

98
// sampler types
109
enum class llama_sampler_type : char {
@@ -59,119 +58,16 @@ typedef struct gpt_sampling_params {
5958
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
6059
} gpt_sampling_params;
6160

62-
// the ring buffer works similarly to std::deque, but with a fixed capacity
63-
template<typename T>
64-
struct ring_buffer {
65-
ring_buffer() {}
66-
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
67-
68-
T & front() {
69-
if (sz == 0) {
70-
throw std::runtime_error("ring buffer is empty");
71-
}
72-
return data[first];
73-
}
74-
75-
const T & front() const {
76-
if (sz == 0) {
77-
throw std::runtime_error("ring buffer is empty");
78-
}
79-
return data[first];
80-
}
81-
82-
T & back() {
83-
if (sz == 0) {
84-
throw std::runtime_error("ring buffer is empty");
85-
}
86-
return data[pos];
87-
}
88-
89-
const T & back() const {
90-
if (sz == 0) {
91-
throw std::runtime_error("ring buffer is empty");
92-
}
93-
return data[pos];
94-
}
95-
96-
void push_back(const T & value) {
97-
if (sz == capacity) {
98-
// advance the start when buffer is full
99-
first = (first + 1) % capacity;
100-
} else {
101-
sz++;
102-
}
103-
data[pos] = value;
104-
pos = (pos + 1) % capacity;
105-
}
106-
107-
T pop_front() {
108-
if (sz == 0) {
109-
throw std::runtime_error("ring buffer is empty");
110-
}
111-
T value = data[first];
112-
first = (first + 1) % capacity;
113-
sz--;
114-
return value;
115-
}
116-
117-
T & operator[](size_t i) {
118-
if (i >= sz) {
119-
throw std::runtime_error("ring buffer: index out of bounds");
120-
}
121-
return data[(first + i) % capacity];
122-
}
123-
124-
const T & operator[](size_t i) const {
125-
if (i >= sz) {
126-
throw std::runtime_error("ring buffer: index out of bounds");
127-
}
128-
return data[(first + i) % capacity];
129-
}
130-
131-
std::vector<T> to_vector() const {
132-
std::vector<T> result;
133-
result.reserve(sz);
134-
for (size_t i = 0; i < sz; i++) {
135-
result.push_back(data[(first + i) % capacity]);
136-
}
137-
return result;
138-
}
139-
140-
void clear() {
141-
// here only reset the status of the buffer
142-
sz = 0;
143-
first = 0;
144-
pos = 0;
145-
}
146-
147-
bool empty() const {
148-
return sz == 0;
149-
}
150-
151-
size_t size() const {
152-
return sz;
153-
}
154-
155-
size_t capacity = 0;
156-
size_t sz = 0;
157-
size_t first = 0;
158-
size_t pos = 0;
159-
std::vector<T> data;
160-
};
161-
16261
// general sampler context
16362
// TODO: move to llama.h
16463
struct llama_sampling_context {
16564
// parameters that will be used for sampling
16665
gpt_sampling_params params;
16766

168-
// mirostat sampler state
169-
float mirostat_mu;
170-
17167
llama_sampling * smpl;
17268

173-
ring_buffer<llama_token> prev;
17469
std::vector<llama_token_data> cur;
70+
std::vector<llama_token_data> org;
17571

17672
size_t n_valid; // Number of correct top tokens with correct probabilities.
17773
};
@@ -189,10 +85,10 @@ void llama_sampling_reset(llama_sampling_context * ctx);
18985
// Copy the sampler context
19086
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
19187

192-
// Get the last sampled token
88+
// Get the last accepted token
19389
llama_token llama_sampling_last(llama_sampling_context * ctx);
19490

195-
// Get a string representation of the last sampled tokens
91+
// Get a string representation of the last accepted tokens
19692
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
19793

19894
// Print sampling parameters into a string
@@ -206,6 +102,13 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
206102
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
207103
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
208104

105+
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
106+
llama_token_data_array llama_sampling_prepare(
107+
struct llama_sampling_context * ctx_sampling,
108+
struct llama_context * ctx_main,
109+
struct llama_context * ctx_cfg,
110+
int idx = 0);
111+
209112
// this is a common sampling function used across the examples for convenience
210113
// it can serve as a starting point for implementing your own sampling function
211114
// Note: When using multiple sequences, it is the caller's responsibility to call
@@ -223,20 +126,15 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
223126
// - token: sampled token
224127
// - candidates: vector of candidate tokens
225128
//
226-
llama_token llama_sampling_sample(
227-
struct llama_sampling_context * ctx_sampling,
228-
struct llama_context * ctx_main,
229-
struct llama_context * ctx_cfg,
230-
int idx = -1);
129+
//llama_token llama_sampling_sample(
130+
// struct llama_sampling_context * ctx_sampling,
131+
// struct llama_token_data_array * cur_p);
231132

232-
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
233-
llama_token_data_array llama_sampling_prepare(
133+
llama_token llama_sampling_sample(
234134
struct llama_sampling_context * ctx_sampling,
235135
struct llama_context * ctx_main,
236136
struct llama_context * ctx_cfg,
237-
int idx = 0,
238-
bool apply_grammar = true,
239-
std::vector<float> * original_logits = nullptr);
137+
int idx = 0);
240138

241139
void llama_sampling_accept(
242140
struct llama_sampling_context * ctx_sampling,

examples/batched.swift/Sources/main.swift

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ defer {
5050
llama_free(context)
5151
}
5252

53-
let smpl = llama_sampling_init(model, llama_sampling_default_params())
53+
var sparams = llama_sampling_params()
54+
sparams.top_k = 40
55+
sparams.top_p = 0.9
56+
sparams.temp = 0.4
57+
58+
let smpl = llama_sampling_init(model, sparams)
5459
guard smpl != nil else {
5560
print("Failed to initialize sampling")
5661
exit(1)
@@ -146,13 +151,9 @@ while n_cur <= n_len {
146151
sorted: false
147152
)
148153

149-
let top_k: Int32 = 40
150-
let top_p: Float = 0.9
151-
let temp: Float = 0.4
152-
153-
llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
154-
llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
155-
llama_sampling_temp(smpl, &candidates_p, temp)
154+
llama_sampling_top_k(smpl, &candidates_p)
155+
llama_sampling_top_p(smpl, &candidates_p)
156+
llama_sampling_temp (smpl, &candidates_p)
156157

157158
let new_token_id = llama_sampling_sample(smpl, &candidates_p)
158159

examples/batched/batched.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,13 @@ int main(int argc, char ** argv) {
6464
ctx_params.n_batch = std::max(n_predict, n_parallel);
6565

6666
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
67-
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
67+
68+
auto sparams = llama_sampling_default_params();
69+
sparams.top_k = 40;
70+
sparams.top_p = 0.9f;
71+
sparams.temp = 0.4f;
72+
73+
llama_sampling * smpl = llama_sampling_init(model, sparams);
6874

6975
if (ctx == NULL) {
7076
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -177,13 +183,9 @@ int main(int argc, char ** argv) {
177183

178184
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
179185

180-
const int top_k = 40;
181-
const float top_p = 0.9f;
182-
const float temp = 0.4f;
183-
184-
llama_sampling_top_k(smpl, &candidates_p, top_k, 1);
185-
llama_sampling_top_p(smpl, &candidates_p, top_p, 1);
186-
llama_sampling_temp (smpl, &candidates_p, temp);
186+
llama_sampling_top_k(smpl, &candidates_p);
187+
llama_sampling_top_p(smpl, &candidates_p);
188+
llama_sampling_temp (smpl, &candidates_p);
187189

188190
const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);
189191

examples/server/server.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2360,9 +2360,10 @@ struct server_context {
23602360
const size_t n_valid = slot.ctx_sampling->n_valid;
23612361

23622362
// Make sure at least n_probs top tokens are at the front of the vector:
2363-
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
2364-
llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
2365-
}
2363+
// TODO: decide to how to handle this after the refactoring
2364+
//if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
2365+
// llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
2366+
//}
23662367

23672368
if (slot.sparams.temp == 0.0f) {
23682369
// With greedy sampling the probabilities have possibly not been calculated.

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ int main(int argc, char ** argv) {
181181
// draft sequence data
182182
std::vector<seq_draft> drafts(n_seq_dft);
183183

184-
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
185184
if (params.sparams.temp == 0) {
186185
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
187186
}
@@ -231,7 +230,8 @@ int main(int argc, char ** argv) {
231230
if (params.sparams.temp > 0) {
232231
// stochastic verification
233232

234-
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
233+
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
234+
llama_sampling_grammar(ctx_sampling->smpl, &dist_tgt);
235235
llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt);
236236

237237
float p_tgt = 0.0f;

0 commit comments

Comments (0)