#include <string>
#include <vector>
- #include <stdexcept>

// sampler types
enum class llama_sampler_type : char {
@@ -59,119 +58,16 @@ typedef struct gpt_sampling_params {
    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
} gpt_sampling_params;

- // the ring buffer works similarly to std::deque, but with a fixed capacity
- template <typename T>
- struct ring_buffer {
-     ring_buffer() {}
-     ring_buffer(size_t cap) : capacity(cap), data(cap) {}
-
-     T & front() {
-         if (sz == 0) {
-             throw std::runtime_error("ring buffer is empty");
-         }
-         return data[first];
-     }
-
-     const T & front() const {
-         if (sz == 0) {
-             throw std::runtime_error("ring buffer is empty");
-         }
-         return data[first];
-     }
-
-     T & back() {
-         if (sz == 0) {
-             throw std::runtime_error("ring buffer is empty");
-         }
-         return data[pos];
-     }
-
-     const T & back() const {
-         if (sz == 0) {
-             throw std::runtime_error("ring buffer is empty");
-         }
-         return data[pos];
-     }
-
-     void push_back(const T & value) {
-         if (sz == capacity) {
-             // advance the start when buffer is full
-             first = (first + 1) % capacity;
-         } else {
-             sz++;
-         }
-         data[pos] = value;
-         pos = (pos + 1) % capacity;
-     }
-
-     T pop_front() {
-         if (sz == 0) {
-             throw std::runtime_error("ring buffer is empty");
-         }
-         T value = data[first];
-         first = (first + 1) % capacity;
-         sz--;
-         return value;
-     }
-
-     T & operator[](size_t i) {
-         if (i >= sz) {
-             throw std::runtime_error("ring buffer: index out of bounds");
-         }
-         return data[(first + i) % capacity];
-     }
-
-     const T & operator[](size_t i) const {
-         if (i >= sz) {
-             throw std::runtime_error("ring buffer: index out of bounds");
-         }
-         return data[(first + i) % capacity];
-     }
-
-     std::vector<T> to_vector() const {
-         std::vector<T> result;
-         result.reserve(sz);
-         for (size_t i = 0; i < sz; i++) {
-             result.push_back(data[(first + i) % capacity]);
-         }
-         return result;
-     }
-
-     void clear() {
-         // here only reset the status of the buffer
-         sz = 0;
-         first = 0;
-         pos = 0;
-     }
-
-     bool empty() const {
-         return sz == 0;
-     }
-
-     size_t size() const {
-         return sz;
-     }
-
-     size_t capacity = 0;
-     size_t sz = 0;
-     size_t first = 0;
-     size_t pos = 0;
-     std::vector<T> data;
- };
-
// general sampler context
// TODO: move to llama.h
struct llama_sampling_context {
    // parameters that will be used for sampling
    gpt_sampling_params params;

-     // mirostat sampler state
-     float mirostat_mu;
-
    llama_sampling * smpl;

-     ring_buffer<llama_token> prev;
    std::vector<llama_token_data> cur;
+     std::vector<llama_token_data> org;

    size_t n_valid; // Number of correct top tokens with correct probabilities.
};
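
For reference, a minimal standalone sketch (not part of this commit) of the wrap-around behavior of the ring_buffer removed above: once the buffer is full, push_back overwrites the oldest element, and front()/to_vector() stay relative to the logical start. It assumes the ring_buffer<T> definition from the removed hunk is in scope; main() and the values are illustrative only.

// standalone demo, assuming the removed ring_buffer<T> definition above is in scope
#include <cstdio>

int main() {
    ring_buffer<int> buf(3);

    buf.push_back(10);
    buf.push_back(20);
    buf.push_back(30);
    buf.push_back(40); // buffer is full: the oldest element (10) is overwritten

    printf("front = %d\n", buf.front()); // front = 20

    for (int v : buf.to_vector()) {
        printf("%d ", v);                // 20 30 40
    }
    printf("\n");

    return 0;
}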
@@ -189,10 +85,10 @@ void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);

- // Get the last sampled token
+ // Get the last accepted token
llama_token llama_sampling_last(llama_sampling_context * ctx);

- // Get a string representation of the last sampled tokens
+ // Get a string representation of the last accepted tokens
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);

// Print sampling parameters into a string
@@ -206,6 +102,13 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);

+ // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+ llama_token_data_array llama_sampling_prepare(
+         struct llama_sampling_context * ctx_sampling,
+         struct llama_context * ctx_main,
+         struct llama_context * ctx_cfg,
+         int idx = 0);
+
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
@@ -223,20 +126,15 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
// - token: sampled token
// - candidates: vector of candidate tokens
//
- llama_token llama_sampling_sample(
-         struct llama_sampling_context * ctx_sampling,
-         struct llama_context * ctx_main,
-         struct llama_context * ctx_cfg,
-         int idx = -1);
+ // llama_token llama_sampling_sample(
+ //         struct llama_sampling_context * ctx_sampling,
+ //         struct llama_token_data_array * cur_p);

- // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
- llama_token_data_array llama_sampling_prepare(
+ llama_token llama_sampling_sample(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
-         int idx = 0,
-         bool apply_grammar = true,
-         std::vector<float> * original_logits = nullptr);
+         int idx = 0);

void llama_sampling_accept(
        struct llama_sampling_context * ctx_sampling,
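
Taken together, the declarations in this diff suggest a per-token flow along the following lines. This is a rough sketch, not code from the commit: ctx_main, ctx_sampling, and n_predict are placeholder names, ctx_cfg is passed as nullptr, and the llama_sampling_accept call is left commented out because its argument list is truncated in the hunk above.

// hypothetical generation loop built only from the declarations above
for (int i = 0; i < n_predict; i++) {
    // sample the next token with the common sampling function (idx defaults to 0)
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, /*ctx_cfg=*/nullptr);

    // callers that want to post-process candidates themselves can instead call
    // llama_sampling_prepare(ctx_sampling, ctx_main, nullptr) and work with the
    // returned llama_token_data_array directly

    // record the accepted token so penalties and other sampler state stay in sync;
    // exact arguments follow the llama_sampling_accept declaration above, which is
    // cut off in this diff
    // llama_sampling_accept(ctx_sampling, ...);

    // feed `id` back to the model (e.g. via llama_decode) before the next iteration
}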