Skip to content

Commit bf59e42

Browse files
committed
Added an implementation of the DRY (Don't Repeat Yourself) sampler
1 parent 4e96a81 commit bf59e42

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

common/sampling.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
256256

257257
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
258258

259+
// repetition penalties
259260
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
260261
const float penalty_repeat = params.penalty_repeat;
261262
const float penalty_freq = params.penalty_freq;
262263
const float penalty_present = params.penalty_present;
263-
264264
const bool penalize_nl = params.penalize_nl;
265265

266+
// DRY sampler parameters
267+
const float dry_multiplier = params.dry_multiplier;
268+
const float dry_base = params.dry_base;
269+
const int dry_allowed_length = params.dry_allowed_length;
270+
266271
auto & prev = ctx_sampling->prev;
267272
auto & cur = ctx_sampling->cur;
268273

@@ -298,10 +303,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
298303
if (penalty_tokens_used_size) {
299304
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
300305

306+
// repetition penalties
301307
llama_sample_repetition_penalties(ctx_main, &cur_p,
302308
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
303309
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
304310

311+
// DRY penalties (multiplier > 0 means enabled)
312+
if(dry_multiplier > 0.0f) {
313+
llama_sample_dry(ctx_main, &cur_p,
314+
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
315+
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
316+
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
317+
}
318+
319+
305320
if (!penalize_nl) {
306321
for (size_t idx = 0; idx < cur_p.size; idx++) {
307322
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {

common/sampling.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ typedef struct llama_sampling_params {
3838
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
3939
float mirostat_tau = 5.00f; // target entropy
4040
float mirostat_eta = 0.10f; // learning rate
41-
bool penalize_nl = false; // consider newlines as a repeatable token
41+
bool penalize_nl = false; // consider newlines as a repeatable token
42+
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
43+
float dry_base = 1.75f;
44+
int dry_allowed_length = 2;
4245

4346
std::vector<llama_sampler_type> samplers_sequence = {
4447
llama_sampler_type::TOP_K,
@@ -59,6 +62,7 @@ typedef struct llama_sampling_params {
5962
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
6063

6164
std::vector<llama_token> penalty_prompt_tokens;
65+
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
6266
bool use_penalty_prompt_tokens = false;
6367
} llama_sampling_params;
6468

llama.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13044,6 +13044,64 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1304413044
}
1304513045
}
1304613046

13047+
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
13048+
// loop through each candidate
13049+
for (size_t i = 0; i < candidates->size; ++i) {
13050+
13051+
// if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
13052+
if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
13053+
continue;
13054+
}
13055+
13056+
int max_match_length = 0;
13057+
13058+
// loop through each previous token
13059+
for (size_t j = 0; j < last_token_size; ++j) {
13060+
// if the current candidate is the same as the previous token
13061+
if (candidates->data[i].id == last_tokens[j]) {
13062+
// greedily match sequence backwards starting from the current position with the end of prev
13063+
int match_length = 1;
13064+
13065+
// loop through the previous tokens
13066+
for(;; match_length++) {
13067+
// if we have reached the start of our stored prev, break
13068+
if(j - match_length > 0) break;
13069+
13070+
// this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
13071+
// but let's check here to avoid the unexpected
13072+
if(last_token_size - match_length < 0) break;
13073+
13074+
// compare token starts at our prev index, going backwards by match length
13075+
auto compare_token = last_tokens[j - match_length];
13076+
13077+
// head token starts at the end of prev, going backwards by match length
13078+
auto head_token = last_tokens[last_token_size - match_length];
13079+
13080+
// if compare token is part of the sequence breakers, break out of the match
13081+
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
13082+
break;
13083+
13084+
// break out of the match if any tokens don't match
13085+
if(compare_token != head_token)
13086+
break;
13087+
}
13088+
13089+
// update our max match length
13090+
max_match_length = std::max(max_match_length, match_length);
13091+
}
13092+
}
13093+
13094+
// apply penalties
13095+
if(max_match_length > dry_allowed_length) {
13096+
// calculate the penalty
13097+
float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length);
13098+
13099+
// apply the dry penalty
13100+
candidates->data[i].logit -= penalty;
13101+
}
13102+
}
13103+
}
13104+
1304713105
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
1304813106
if (z >= 1.0f || candidates->size <= 2) {
1304913107
return;

llama.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,18 @@ extern "C" {
922922
float p,
923923
size_t min_keep);
924924

925+
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
926+
LLAMA_API void llama_sample_dry(
927+
struct llama_context * ctx,
928+
llama_token_data_array * candidates,
929+
const llama_token * last_tokens,
930+
int last_token_size,
931+
float dry_base,
932+
float dry_multiplier,
933+
int dry_allowed_length,
934+
const llama_token * seq_breakers,
935+
int seq_breakers_size);
936+
925937
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
926938
LLAMA_API void llama_sample_tail_free(
927939
struct llama_context * ctx,

0 commit comments

Comments
 (0)