Skip to content

Commit d8b47da

Browse files
committed
refactored sample_dry to be more efficient and closer to original implementation
1 parent babff06 commit d8b47da

File tree

1 file changed

+59
-37
lines changed

1 file changed

+59
-37
lines changed

llama.cpp

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13045,58 +13045,80 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1304513045
}
1304613046

1304713047
void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) {
    // DRY (Don't Repeat Yourself) sampler: penalizes candidate tokens that would
    // extend a repetition of a sequence already present in the recent context.
    //
    // For every earlier occurrence of the last context token, we measure how long
    // the context suffix matches the text around that occurrence (walking
    // backwards), and remember, per "next token" (the token that followed the
    // occurrence), the longest such match. Candidates whose id appears in that
    // map with a match longer than dry_allowed_length get an exponential logit
    // penalty: dry_multiplier * dry_base^(match_length - dry_allowed_length).

    // sanity check: we need at least one context token to anchor matches on
    GGML_ASSERT(last_tokens_size > 0);

    // a negative allowed length would wrap around in the unsigned comparison
    // below (size_t vs int) and silently disable all penalties; clamp it to 0
    const size_t allowed_length = dry_allowed_length < 0 ? 0 : (size_t) dry_allowed_length;

    // the most recent context token; every match must end with it
    const llama_token last_token = last_tokens[last_tokens_size - 1];

    // map of "next token" -> longest match length found for that token
    std::unordered_map<llama_token, size_t> match_lengths;

    // scan the context (excluding the final token) for earlier occurrences of the last token
    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
        // only positions holding the same token as the last token can anchor a match
        if (last_tokens[i] != last_token) {
            continue;
        }

        // the token that followed this occurrence (i + 1 is always < last_tokens_size here)
        const llama_token next_token = last_tokens[i + 1];

        // try to extend the match backwards (starts at 1 because the anchor token already matches)
        size_t match_length = 1;
        for (;; match_length++) {
            // stop when the backwards walk would run off the start of the context
            if (i < match_length) break;

            // compare token starts at our anchor index, going backwards by match length
            const llama_token compare_token = last_tokens[i - match_length];

            // head token starts at the end of the context, going backwards by match
            // length, minus 1 because we start at the last token itself
            const llama_token head_token = last_tokens[last_tokens_size - 1 - match_length];

            // a sequence breaker terminates the match
            if (std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) {
                break;
            }

            // any mismatch terminates the match
            if (compare_token != head_token) {
                break;
            }
        }

        // keep the longest match length seen for this next token
        auto it = match_lengths.find(next_token);
        if (it == match_lengths.end()) {
            match_lengths[next_token] = match_length;
        } else {
            it->second = std::max(it->second, match_length);
        }
    }

    // nothing matched: no penalties to apply
    if (match_lengths.empty()) {
        return;
    }

    // apply penalties in a single pass over the candidates (one O(1) map lookup
    // per candidate, instead of one full candidate scan per matched token)
    for (size_t i = 0; i < candidates->size; ++i) {
        const auto it = match_lengths.find(candidates->data[i].id);
        if (it == match_lengths.end()) {
            continue;
        }

        const size_t match_length = it->second;

        // only penalize repetitions longer than the configured allowance
        if (match_length > allowed_length) {
            // exponential penalty in the excess match length
            const float penalty = dry_multiplier * std::pow(dry_base, (float) (match_length - allowed_length));

            // apply the dry penalty
            candidates->data[i].logit -= penalty;
        }
    }
}

0 commit comments

Comments
 (0)