Skip to content

Commit 4d2873a

Browse files
kylo5abyggerganov
authored andcommitted
sampling : use ring buffer to store prev tokens (#8890)
1 parent e8dbe04 commit 4d2873a

File tree

4 files changed

+110
-8
lines changed

4 files changed

+110
-8
lines changed

common/sampling.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
4040
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
4141
}
4242

43-
result->prev.resize(params.n_prev);
43+
result->prev = ring_buffer<llama_token>(params.n_prev);
4444

4545
result->n_valid = 0;
4646

@@ -56,7 +56,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
5656
void llama_sampling_reset(llama_sampling_context * ctx) {
5757
llama_sampling_reset(ctx->smpl);
5858

59-
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
59+
ctx->prev.clear();
6060
ctx->cur.clear();
6161
ctx->n_valid = 0;
6262
}
@@ -384,7 +384,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
384384
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
385385

386386
// apply penalties
387-
const auto & penalty_tokens = prev;
387+
const auto & penalty_tokens = prev.to_vector();
388388
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
389389
if (penalty_tokens_used_size) {
390390
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
@@ -434,7 +434,9 @@ void llama_sampling_accept(
434434
struct llama_sampling_context * ctx_sampling,
435435
llama_token id,
436436
bool apply_grammar) {
437-
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
437+
if (!ctx_sampling->prev.empty()) {
438+
ctx_sampling->prev.pop_front();
439+
}
438440
ctx_sampling->prev.push_back(id);
439441

440442
if (apply_grammar) {

common/sampling.h

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <string>
66
#include <vector>
7+
#include <stdexcept>
78

89
// sampler types
910
enum class llama_sampler_type : char {
@@ -58,6 +59,106 @@ typedef struct gpt_sampling_params {
5859
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
5960
} gpt_sampling_params;
6061

62+
// the ring buffer works similarly to std::deque, but with a fixed capacity
63+
template<typename T>
64+
struct ring_buffer {
65+
ring_buffer() {}
66+
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
67+
68+
T & front() {
69+
if (sz == 0) {
70+
throw std::runtime_error("ring buffer is empty");
71+
}
72+
return data[first];
73+
}
74+
75+
const T & front() const {
76+
if (sz == 0) {
77+
throw std::runtime_error("ring buffer is empty");
78+
}
79+
return data[first];
80+
}
81+
82+
T & back() {
83+
if (sz == 0) {
84+
throw std::runtime_error("ring buffer is empty");
85+
}
86+
return data[pos];
87+
}
88+
89+
const T & back() const {
90+
if (sz == 0) {
91+
throw std::runtime_error("ring buffer is empty");
92+
}
93+
return data[pos];
94+
}
95+
96+
void push_back(const T & value) {
97+
if (sz == capacity) {
98+
// advance the start when buffer is full
99+
first = (first + 1) % capacity;
100+
} else {
101+
sz++;
102+
}
103+
data[pos] = value;
104+
pos = (pos + 1) % capacity;
105+
}
106+
107+
T pop_front() {
108+
if (sz == 0) {
109+
throw std::runtime_error("ring buffer is empty");
110+
}
111+
T value = data[first];
112+
first = (first + 1) % capacity;
113+
sz--;
114+
return value;
115+
}
116+
117+
T & operator[](size_t i) {
118+
if (i >= sz) {
119+
throw std::runtime_error("ring buffer: index out of bounds");
120+
}
121+
return data[(first + i) % capacity];
122+
}
123+
124+
const T & operator[](size_t i) const {
125+
if (i >= sz) {
126+
throw std::runtime_error("ring buffer: index out of bounds");
127+
}
128+
return data[(first + i) % capacity];
129+
}
130+
131+
std::vector<T> to_vector() const {
132+
std::vector<T> result;
133+
result.reserve(sz);
134+
for (size_t i = 0; i < sz; i++) {
135+
result.push_back(data[(first + i) % capacity]);
136+
}
137+
return result;
138+
}
139+
140+
void clear() {
141+
// here only reset the status of the buffer
142+
sz = 0;
143+
first = 0;
144+
pos = 0;
145+
}
146+
147+
bool empty() const {
148+
return sz == 0;
149+
}
150+
151+
size_t size() const {
152+
return sz;
153+
}
154+
155+
size_t capacity = 0;
156+
size_t sz = 0;
157+
size_t first = 0;
158+
size_t pos = 0;
159+
std::vector<T> data;
160+
};
161+
61162
// general sampler context
62163
// TODO: move to llama.h
63164
struct llama_sampling_context {
@@ -69,8 +170,7 @@ struct llama_sampling_context {
69170

70171
llama_sampling * smpl;
71172

72-
// TODO: replace with ring-buffer
73-
std::vector<llama_token> prev;
173+
ring_buffer<llama_token> prev;
74174
std::vector<llama_token_data> cur;
75175

76176
size_t n_valid; // Number of correct top tokens with correct probabilities.

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ int main(int argc, char ** argv) {
421421

422422
llama_sampling_accept(ctx_sampling, id, true);
423423

424-
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
424+
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
425425

426426
embd.push_back(id);
427427

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ int main(int argc, char ** argv) {
733733

734734
llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
735735

736-
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
736+
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
737737

738738
embd.push_back(id);
739739

0 commit comments

Comments
 (0)