Skip to content

Commit 99b7760

Browse files
committed
added dry sampler implementation
1 parent b03b419 commit 99b7760

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

common/sampling.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,18 @@ static llama_token_data_array llama_sampling_prepare_impl(
260260

261261
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
262262

263+
// repetition penalties
263264
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
264265
const float penalty_repeat = params.penalty_repeat;
265266
const float penalty_freq = params.penalty_freq;
266267
const float penalty_present = params.penalty_present;
267-
268268
const bool penalize_nl = params.penalize_nl;
269269

270+
// DRY sampler parameters
271+
const float dry_multiplier = params.dry_multiplier;
272+
const float dry_base = params.dry_base;
273+
const int dry_allowed_length = params.dry_allowed_length;
274+
270275
auto & prev = ctx_sampling->prev;
271276
auto & cur = ctx_sampling->cur;
272277

@@ -302,10 +307,20 @@ static llama_token_data_array llama_sampling_prepare_impl(
302307
if (penalty_tokens_used_size) {
303308
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
304309

310+
// repetition penalties
305311
llama_sample_repetition_penalties(ctx_main, &cur_p,
306312
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
307313
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
308314

315+
// DRY penalties (multiplier > 0 means enabled)
316+
if(dry_multiplier > 0.0f) {
317+
llama_sample_dry(ctx_main, &cur_p,
318+
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
319+
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
320+
params.dry_sequence_breakers.data(), params.dry_sequence_breakers.size());
321+
}
322+
323+
309324
if (!penalize_nl) {
310325
for (size_t idx = 0; idx < cur_p.size; idx++) {
311326
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {

common/sampling.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ typedef struct llama_sampling_params {
3838
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
3939
float mirostat_tau = 5.00f; // target entropy
4040
float mirostat_eta = 0.10f; // learning rate
41-
bool penalize_nl = false; // consider newlines as a repeatable token
41+
bool penalize_nl = false; // consider newlines as a repeatable token
42+
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
43+
float dry_base = 1.75f;
44+
int dry_allowed_length = 2;
4245

4346
std::vector<llama_sampler_type> samplers_sequence = {
4447
llama_sampler_type::TOP_K,
@@ -59,6 +62,7 @@ typedef struct llama_sampling_params {
5962
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
6063

6164
std::vector<llama_token> penalty_prompt_tokens;
65+
std::vector<llama_token> dry_sequence_breakers; // sequence breakers for the DRY sampler
6266
bool use_penalty_prompt_tokens = false;
6367
} llama_sampling_params;
6468

llama.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12832,6 +12832,64 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1283212832
}
1283312833
}
1283412834

12835+
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
12836+
// loop through each candidate
12837+
for (size_t i = 0; i < candidates->size; ++i) {
12838+
12839+
// if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
12840+
if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
12841+
continue;
12842+
}
12843+
12844+
int max_match_length = 0;
12845+
12846+
// loop through each previous token
12847+
for (size_t j = 0; j < last_token_size; ++j) {
12848+
// if the current candidate is the same as the previous token
12849+
if (candidates->data[i].id == last_tokens[j]) {
12850+
// greedily match sequence backwards starting from the current position with the end of prev
12851+
int match_length = 1;
12852+
12853+
// loop through the previous tokens
12854+
for(;; match_length++) {
12855+
// if we have reached the start of our stored prev, break
12856+
if(j - match_length > 0) break;
12857+
12858+
// this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
12859+
// but let's check here to avoid the unexpected
12860+
if(last_token_size - match_length < 0) break;
12861+
12862+
// compare token starts at our prev index, going backwards by match length
12863+
auto compare_token = last_tokens[j - match_length];
12864+
12865+
// head token starts at the end of prev, going backwards by match length
12866+
auto head_token = last_tokens[last_token_size - match_length];
12867+
12868+
// if compare token is part of the sequence breakers, break out of the match
12869+
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
12870+
break;
12871+
12872+
// break out of the match if any tokens don't match
12873+
if(compare_token != head_token)
12874+
break;
12875+
}
12876+
12877+
// update our max match length
12878+
max_match_length = std::max(max_match_length, match_length);
12879+
}
12880+
}
12881+
12882+
// apply penalties
12883+
if(max_match_length > dry_allowed_length) {
12884+
// calculate the penalty
12885+
float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length);
12886+
12887+
// apply the dry penalty
12888+
candidates->data[i].logit -= penalty;
12889+
}
12890+
}
12891+
}
12892+
1283512893
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
1283612894
if (z >= 1.0f || candidates->size <= 2) {
1283712895
return;

llama.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,18 @@ extern "C" {
918918
float p,
919919
size_t min_keep);
920920

921+
/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
922+
LLAMA_API void llama_sample_dry(
923+
struct llama_context * ctx,
924+
llama_token_data_array * candidates,
925+
const llama_token * last_tokens,
926+
int last_token_size,
927+
float dry_base,
928+
float dry_multiplier,
929+
int dry_allowed_length,
930+
const llama_token * seq_breakers,
931+
int seq_breakers_size);
932+
921933
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
922934
LLAMA_API void llama_sample_tail_free(
923935
struct llama_context * ctx,

0 commit comments

Comments
 (0)