
Commit b0f2736

ggerganov and slaren authored
sampling : avoid expensive softmax during greedy sampling (#9605)
* sampling : avoid expensive softmax during greedy sampling

  ggml-ci

* speculative : fix default RNG seed + set sparams.n_probs

* Update tests/test-sampling.cpp

  Co-authored-by: slaren <[email protected]>

* sampling : add clarifying comment [no ci]

---------

Co-authored-by: slaren <[email protected]>
1 parent c087b6f commit b0f2736

5 files changed: +59, -6 lines

common/sampling.cpp

Lines changed: 9 additions & 1 deletion

@@ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                 GGML_ASSERT(false && "unknown mirostat version");
         }
     } else {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        if (params.n_probs > 0) {
+            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
+            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
+            //
+            // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
+            // it is much faster, since we avoid sorting all tokens and should give a good approximation
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        }
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
 
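For context, the chain built above can also be assembled directly against the public sampler API. The following is a minimal sketch, not code from this commit, assuming the llama.h functions referenced in the diff; k = 128 is only an illustrative stand-in for params.n_probs:

    // sketch: greedy sampling that still reports probabilities for the top tokens
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(128)); // keep only the 128 largest logits, so sorting stays cheap
    llama_sampler_chain_add(chain, llama_sampler_init_softmax());  // normalize probabilities over those 128 candidates only
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());   // still select the single most likely token

    // ... apply per token via llama_sampler_sample()/llama_sampler_apply() ...

    llama_sampler_free(chain);

The greedy pick is unchanged by the truncation; only the reported probabilities become approximations, normalized over the surviving candidates instead of the full vocabulary.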

examples/speculative/speculative.cpp

Lines changed: 4 additions & 1 deletion

@@ -32,6 +32,9 @@ struct seq_draft {
 int main(int argc, char ** argv) {
     gpt_params params;
 
+    // needed to get candidate probs even for temp <= 0.0
+    params.sparams.n_probs = 128;
+
     if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -49,7 +52,7 @@ int main(int argc, char ** argv) {
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split = params.p_split;
 
-    std::default_random_engine rng(params.sparams.seed);
+    std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
     std::uniform_real_distribution<> u_dist;
 
     // init llama.cpp
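The RNG change follows a standard seeding pattern: honor an explicit user seed for reproducibility, otherwise draw a non-deterministic seed from std::random_device. A standalone sketch of just that logic (the default_seed parameter plays the role of LLAMA_DEFAULT_SEED from llama.h; the helper name is hypothetical):

    #include <cstdint>
    #include <random>

    // sketch only: use a reproducible engine when the user passed a seed,
    // otherwise fall back to OS/hardware entropy
    static std::default_random_engine make_rng(uint32_t seed, uint32_t default_seed) {
        return std::default_random_engine(seed == default_seed ? std::random_device()() : seed);
    }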

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -1066,6 +1066,7 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits.
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
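As a usage illustration of that NOTE, here is a hedged sketch of applying top-k before softmax to a candidate array and reading back the resulting probabilities. It is not code from this commit: the helper name and the logits buffer are made up, while llama_token_data and llama_token_data_array follow llama.h as of this change:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    static void print_top_probs(const float * logits, int n_vocab, int k) {
        std::vector<llama_token_data> cur;
        cur.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            cur.push_back(llama_token_data{ id, logits[id], 0.0f }); // p is filled in by softmax
        }
        llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected*/ -1, /*sorted*/ false };

        llama_sampler * top_k   = llama_sampler_init_top_k(k);  // truncate to the k largest logits
        llama_sampler * softmax = llama_sampler_init_softmax(); // sort + normalize over those k only

        llama_sampler_apply(top_k,   &cur_p);
        llama_sampler_apply(softmax, &cur_p);

        for (size_t i = 0; i < cur_p.size; ++i) {
            printf("token %d: p ~ %.4f\n", cur_p.data[i].id, cur_p.data[i].p);
        }

        llama_sampler_free(top_k);
        llama_sampler_free(softmax);
    }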

src/llama-sampling.cpp

Lines changed: 4 additions & 3 deletions

@@ -3,13 +3,14 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
-#include <cassert>
 #include <algorithm>
-#include <cstring>
-#include <ctime>
+#include <cassert>
 #include <cfloat>
 #include <chrono>
 #include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
 #include <numeric>
 #include <random>
 #include <unordered_map>

tests/test-sampling.cpp

Lines changed: 41 additions & 1 deletion

@@ -1,6 +1,5 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
             samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
+    std::vector<llama_token_data> cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector<llama_token_data> data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k    (40),      data, 32);
+    BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32);
+    BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32);
+    BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_softmax  (),        data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
@@ -316,5 +354,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }
