
Commit 190898a

Merge pull request #30 from wwoodsTM/test-dry-sampler
Working implementation of DRY with one key issue I could use help with
2 parents ed6b909 + a18fb2f

File tree

6 files changed: +208 -54 lines changed

common/sampling.cpp

Lines changed: 2 additions & 2 deletions

@@ -433,10 +433,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
         {
             const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
             if (penalty_tokens_used_size) {
-                llama_sample_dry(&cur_p,
+                llama_sample_dry(ctx_main, &cur_p,
                     penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                     penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
-                    params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
+                    params.dry_seq_breakers);
             }
         }

common/sampling.h

Lines changed: 3 additions & 2 deletions

@@ -46,6 +46,8 @@ typedef struct llama_sampling_params {
     uint32_t dry_allowed_length = 2;
     int32_t  dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
 
+    std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
         llama_sampler_type::TFS_Z,

@@ -63,9 +65,8 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
     std::vector<llama_token> penalty_prompt_tokens;
-    std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
+
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;
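
Because the breakers are now plain strings rather than token ids, they are tokenized against whatever model is loaded, so one config works across tokenizers that assign different ids to the same text. A minimal caller-side sketch, using only the fields visible in this hunk (the "USER:" entry is an illustrative addition, not a default set by this commit):

    llama_sampling_params sparams;
    // override the string-based defaults; each entry is tokenized per model
    // inside the sampler, so no token ids are hard-coded here
    sparams.dry_seq_breakers   = {"\n", ":", "\"", "*", "USER:"};
    sparams.dry_allowed_length = 2;    // repeats up to this length go unpenalized
    sparams.dry_penalty_last_n = 256;  // scan only the last 256 tokens (-1 = context size)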

include/llama.h

Lines changed: 23 additions & 10 deletions

@@ -1085,16 +1085,17 @@
             float p,
             size_t min_keep);
 
-    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
-    LLAMA_API void llama_sample_dry(
-            llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-            size_t last_tokens_size,
-            float dry_base,
-            float dry_multiplier,
-            int dry_allowed_length,
-            const llama_token * dry_seq_breakers,
-            size_t dry_seq_breakers_size);
+    // /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    // LLAMA_API void llama_sample_dry(
+    //         struct llama_context * ctx,
+    //         llama_token_data_array * candidates,
+    //         const llama_token * last_tokens,
+    //         size_t last_tokens_size,
+    //         float dry_base,
+    //         float dry_multiplier,
+    //         int dry_allowed_length,
+    //         const std::vector<std::string>
+    //         & dry_seq_breakers);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(

@@ -1246,6 +1247,18 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
 llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
 
+/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+LLAMA_API void llama_sample_dry(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        const llama_token * last_tokens,
+        size_t last_tokens_size,
+        float dry_base,
+        float dry_multiplier,
+        int dry_allowed_length,
+        const std::vector<std::string>
+        & dry_seq_breakers);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
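
Taken together with the call site in common/sampling.cpp above, the relocated internal declaration would be invoked roughly like this. Here `ctx`, `cur_p`, and `prev` are assumed caller-side names, and the numeric values are illustrative defaults suggested in the linked text-generation-webui PR, not anything fixed by this header:

    // hypothetical call site; assumes prev holds the recent context tokens
    std::vector<std::string> breakers = {"\n", ":", "\"", "*"};
    llama_sample_dry(ctx, &cur_p,
            prev.data(), prev.size(),   // last tokens and their count
            1.75f,                      // dry_base (illustrative)
            0.8f,                       // dry_multiplier (illustrative)
            2,                          // dry_allowed_length
            breakers);                  // sequence breakers, tokenized internally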

src/llama-sampling.cpp

Lines changed: 173 additions & 37 deletions

@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
     }
 }
 
-void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    // skip dry sampler if we don't have a previous token
-    if (last_tokens_size < 1) return;
+std::vector<llama_token> llama_tokenize(
+        const struct llama_context * ctx,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}
+
+std::vector<llama_token> llama_tokenize(
+        const struct llama_model * model,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) {
+    std::vector<llama_token> tokens = {token};
+    return llama_detokenize(ctx, tokens, special);
+}
 
-    // get the last token
-    auto last_token = last_tokens[last_tokens_size - 1];
+// Constants for preventing overflow
+const float FLOAT_MAX_LOG = 88.7228391f;
+const int MAX_CHAR_LEN = 40;
+const int MAX_SEQ_LEN = 20;
 
-    // if last token is part of the sequence breakers, skip whole sampler
-    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+
+void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    if (last_tokens_size < 1) {
         return;
     }
 
-    // create an unordered map of "next tokens" <-> max match length
+    // Cache for token-to-string conversions
+    std::unordered_map<llama_token, std::string> token_to_string_cache;
+    // Store sequence breakers for more efficient lookup
+    std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
+
+    auto detokenize_with_cache = [&](llama_token token) -> std::string {
+        auto it = token_to_string_cache.find(token);
+        if (it != token_to_string_cache.end()) {
+            return it->second;
+        }
+        std::string token_str = llama_detokenize_single(ctx, token, false);
+        token_to_string_cache[token] = token_str;
+        return token_str;
+    };
+
+    // Pre-process dry_seq_breakers
+    for (const auto& breaker : dry_seq_breakers) {
+        std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, breaker_trimmed, false, false);
+
+        if (!tokens.empty()) {
+            std::string head = detokenize_with_cache(tokens[0]);
+            std::vector<std::string> tail;
+
+            for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) {
+                tail.push_back(detokenize_with_cache(tokens[i]));
+            }
+            restart_sequences.emplace(head, tail);
+        }
+    }
+
+    // Find max repetition length considering restart sequences
+    int rep_limit = last_tokens_size;
+
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        size_t ix = last_tokens_size - 1 - i;
+        std::string token_str = detokenize_with_cache(last_tokens[ix]);
+
+        // Check if the token is a potential sequence breaker
+        auto its = restart_sequences.equal_range(token_str);
+        if (its.first == restart_sequences.end()) continue;
+
+        int longest_match = -1;
+        // Check all potential sequence breakers starting with this token
+        for (auto it = its.first; it != its.second; ++it) {
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= i) {
+                bool match = true;
+                // Check if the following tokens match the sequence breaker
+                for (size_t offset = 0; offset < seq_len; ++offset) {
+                    if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+
+        if (longest_match >= 0) {
+            rep_limit = static_cast<int>(i) - longest_match;
+            break;
+        }
+    }
+
+    if (rep_limit <= dry_allowed_length) {
+        return;
+    }
+
+    // Store max match length for each token
     std::unordered_map<llama_token, size_t> match_lengths;
 
-    // loop through each previous token (exclude the last token)
+    // Find repeated sequences
     for (size_t i = 0; i < last_tokens_size - 1; ++i) {
-        // skip if the compare token is not the same as the last token
-        if (last_tokens[i] != last_token) {
+        if (last_tokens[i] != last_tokens[last_tokens_size - 1]) {
             continue;
         }
 
-        // get the next token (i + 1 is always less than last_tokens_size)
         auto next_token = last_tokens[i + 1];
+        std::string next_token_str = detokenize_with_cache(next_token);
 
-        // if next token is part of the sequence breakers, skip
-        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        // Skip if next token is a sequence breaker
+        auto its = restart_sequences.equal_range(next_token_str);
+        if (its.first != restart_sequences.end()) {
             continue;
         }
 
-        // try to extend the match backwards (match length starts at 1 because last token is already matched)
         size_t match_length = 1;
 
-        // loop through the previous tokens
+        // Extend match as far as possible
        for (;; match_length++) {
-            // if we have reached the start of our last tokens, break
-            if (i < match_length) break;
+            if (i < match_length || match_length > rep_limit) {
+                break;
+            }
 
-            // compare token starts at our prev index, going backwards by match length
             auto compare_token = last_tokens[i - match_length];
+            std::string compare_token_str = detokenize_with_cache(compare_token);
 
-            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
             auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+            std::string head_token_str = detokenize_with_cache(head_token);
 
-            // break out of the match if any tokens don't match
-            if (compare_token != head_token) {
+            if (compare_token_str != head_token_str) {
                 break;
             }
 
-            // if compare token is part of the sequence breakers, break out of the match
-            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            // Check if we've hit a sequence breaker
+            its = restart_sequences.equal_range(compare_token_str);
+            if (its.first != restart_sequences.end()) {
                 break;
             }
         }
 
-        // Check if the next token exists in the map
+        // Update max match length for this token
         auto it = match_lengths.find(next_token);
-
         if (it == match_lengths.end()) {
-            // Key does not exist, insert the new value
             match_lengths[next_token] = match_length;
         } else {
-            // Key exists, update it with the max of the new value or the existing value
             it->second = std::max(it->second, match_length);
         }
     }
 
-    // apply penalties
+    // Calculate max safe exponent
+    int max_exponent = 0;
+    if (dry_base > 1.000001f) {
+        max_exponent = static_cast<int>(FLOAT_MAX_LOG / log(dry_base));
+    }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("DRY Sampling parameters:\n");
+    LLAMA_LOG_INFO("  dry_base: %f\n", dry_base);
+    LLAMA_LOG_INFO("  dry_multiplier: %f\n", dry_multiplier);
+    LLAMA_LOG_INFO("  dry_allowed_length: %d\n", dry_allowed_length);
+    LLAMA_LOG_INFO("  max_exponent: %d\n", max_exponent);
+    LLAMA_LOG_INFO("DRY penalties [");
+#endif
+
+    // Apply penalties
     for (const auto& pair : match_lengths) {
         auto next_token = pair.first;
         auto match_length = pair.second;
 
-        // if the match length is greater than or equal to our allowed length in config, we apply penalities
-        if (match_length >= (size_t)dry_allowed_length) {
-
-            // find our next token in the candidates->data
+        if (match_length >= static_cast<size_t>(dry_allowed_length)) {
             for (size_t i = 0; i < candidates->size; ++i) {
                 if (candidates->data[i].id == next_token) {
-                    // calculate the penalty
-                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
-
-                    // apply the dry penalty
+                    int repeat_exp = static_cast<int>(match_length - dry_allowed_length);
+                    if (max_exponent > 0 && repeat_exp > max_exponent) {
+                        repeat_exp = max_exponent;
+                    }
+                    float penalty = dry_multiplier * pow(dry_base, static_cast<float>(repeat_exp));
                     candidates->data[i].logit -= penalty;
+
+#ifdef DEBUG
+                    LLAMA_LOG_INFO("  Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty);
+#endif
                     break;
                 }
             }
         }
     }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("]\n");
+#endif
 }
 
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
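
The penalty applied in the loop above grows geometrically in the repeat length, penalty = dry_multiplier * dry_base^(match_length - dry_allowed_length), and max_exponent caps the exponent so pow() cannot overflow a float (FLOAT_MAX_LOG is roughly log(FLT_MAX)). A self-contained sketch of that schedule, using illustrative parameter values rather than anything fixed by this commit:

    // standalone demonstration of the DRY penalty schedule and its overflow clamp
    #include <cmath>
    #include <cstdio>

    int main() {
        const float FLOAT_MAX_LOG = 88.7228391f;   // ~log(FLT_MAX), as above
        const float dry_base = 1.75f;              // illustrative value
        const float dry_multiplier = 0.8f;         // illustrative value
        const int   dry_allowed_length = 2;

        // largest exponent before pow(dry_base, e) would overflow a float
        const int max_exponent = (int)(FLOAT_MAX_LOG / std::log(dry_base));

        for (int match_length = 2; match_length <= 8; ++match_length) {
            int repeat_exp = match_length - dry_allowed_length;
            if (repeat_exp > max_exponent) repeat_exp = max_exponent;
            float penalty = dry_multiplier * std::pow(dry_base, (float)repeat_exp);
            std::printf("match %d -> logit penalty %.2f\n", match_length, penalty);
        }
        return 0;
    }

With these values a repeat three tokens past the allowed length already subtracts about 4.3 from the logit, so the clamp only matters for extremely long matches (max_exponent works out to around 158 here).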

src/llama-sampling.h

Lines changed: 5 additions & 1 deletion

@@ -28,7 +28,11 @@ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_
 void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_dry_impl      (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size);
+std::vector<llama_token> llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special);
+std::vector<llama_token> llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special);
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special);
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special);
+void llama_sample_dry_impl      (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers);
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
 void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);

src/llama.cpp

Lines changed: 2 additions & 2 deletions

@@ -18948,8 +18948,8 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
 }
 
-void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    llama_sample_dry_impl(candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers, dry_seq_breakers_size);
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
 }
 
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
