@@ -5,11 +5,10 @@
 #undef NDEBUG
 #endif
 
+#include <algorithm>
 #include <cmath>
-#include <numeric>
-#include <cassert>
+#include <string>
 #include <vector>
-#include <algorithm>
 
 static void dump(const llama_token_data_array * candidates) {
     for (size_t i = 0; i < candidates->size; i++) {
@@ -20,11 +19,11 @@ static void dump(const llama_token_data_array * candidates) {
 
 #define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
 
 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
-    size_t n_vocab = probs.size();
+    const size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        float logit = log(probs[token_id]);
+        const float logit = logf(probs[token_id]);
         candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
@@ -41,11 +40,11 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
 }
 
 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    size_t n_vocab = probs.size();
+    const size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        float logit = log(probs[token_id]);
+        const float logit = logf(probs[token_id]);
         candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
@@ -62,11 +61,11 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
 }
 
 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
-    size_t n_vocab = probs.size();
+    const size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        float logit = log(probs[token_id]);
+        const float logit = logf(probs[token_id]);
         candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
@@ -81,12 +80,33 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
     }
 }
 
+static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
+    const size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        const float logit = logf(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    DUMP(&candidates_p);
+    llama_sample_min_p(nullptr, &candidates_p, p, 1);
+    DUMP(&candidates_p);
+    llama_sample_softmax(nullptr, &candidates_p);
+
+    GGML_ASSERT(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    }
+}
+
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
-    size_t n_vocab = probs.size();
+    const size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        float logit = log(probs[token_id]);
+        const float logit = logf(probs[token_id]);
         candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
@@ -107,11 +127,11 @@ static void test_repetition_penalties(
 ) {
     GGML_ASSERT(probs.size() == expected_probs.size());
 
-    size_t n_vocab = probs.size();
+    const size_t n_vocab = probs.size();
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        float logit = log(probs[token_id]);
+        const float logit = logf(probs[token_id]);
         candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
@@ -128,6 +148,88 @@ static void test_repetition_penalties(
     }
 }
 
+static void test_sampler_queue(
+    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
+) {
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        const float logit = logf(token_id);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
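+    // with logit = logf(token_id), softmax makes each token's probability
+    // proportional to its id, so the expected survivors of every sampler
+    // below have a simple closed form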
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    llama_token min_token_id = 0;
+    const llama_token max_token_id = n_vocab-1;
+
+    for (auto s : samplers_sequence) {
+        switch (s){
+            case 'k': llama_sample_top_k(nullptr, &candidates_p, top_k, 1); break;
+            case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
+            case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
+            case 'p': llama_sample_top_p(nullptr, &candidates_p, top_p, 1); break;
+            case 'm': llama_sample_min_p(nullptr, &candidates_p, min_p, 1); break;
+            case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
+            default : GGML_ASSERT(false && "Unknown sampler"); break;
+        }
+
+        llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
+
+        const int size = candidates_p.size;
+
+        if (s == 'k') {
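+            // top_k keeps the top_k highest-probability tokens; since p grows
+            // with the id, the smallest surviving id is n_vocab - top_k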
+            const int expected_size = std::min(size, top_k);
+            min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
+
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
+            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+        } else if (s == 'p') {
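+            // cumulative probability of the surviving tokens is a partial sum of
+            // the arithmetic series of token ids (p(i) is proportional to i); the
+            // divisor is the id sum still in play after earlier truncations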
+            const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
+            const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
+
+            min_token_id = n_vocab;
+            int expected_size = 0;
+            int cumsum = 0;
+            do { // do-while because always at least one token is sampled
+                min_token_id--;
+                expected_size++;
+
+                cumsum += min_token_id;
+            } while (cumsum < softmax_numerator_target);
+
+            // token 0 has p == 0, need special consideration for cumsum because top_p immediately returns
+            if (min_token_id == 1) {
+                min_token_id--;
+                expected_size += 1;
+            }
+
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
+            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+        } else if (s == 'm') {
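+            // min_p keeps tokens with p >= min_p * p_max; with p proportional to
+            // the id and p_max at id n_vocab-1, roughly the ids above
+            // min_p * n_vocab survive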
+            int expected_size = ceilf((1.0f-min_p) * n_vocab);
+            expected_size = std::max(expected_size, 1);
+            expected_size = std::min(expected_size, size);
+
+            min_token_id = floorf(min_p * n_vocab);
+            min_token_id = std::max(min_token_id, 1);
+            min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
+            min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
+
+            GGML_ASSERT(size == expected_size);
+            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
+            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+        } else {
+            GGML_ASSERT(false);
+        }
+    }
+
+    printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
+           samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
+}
+
 int main(void) {
     ggml_time_init();
 
@@ -139,6 +241,15 @@ int main(void) {
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
 
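+    // min_p keeps tokens with p >= min_p * p_max (p_max = 0.4f here); survivors
+    // are renormalized, hence the expected values written as fractions of the
+    // remaining probability mass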
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.26f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.49f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.51f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
+    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
+
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
@@ -154,6 +265,34 @@ int main(void) {
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
+    test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
+    test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);
+
+    test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
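+    // min_p = 9997.9f/9999.0f puts the threshold id at 9997.9, strictly between
+    // ids 9997 and 9998, so exactly the top two tokens (9998 and 9999) survive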
+    test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
+    test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
+
+    test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
+    test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);
+
+    test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
+    test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
+
     printf("OK\n");
 
     return 0;