@@ -13045,58 +13045,80 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
13045
13045
}
13046
13046
13047
13047
void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, size_t seq_breakers_size) {
    // DRY ("Don't Repeat Yourself") repetition penalty.
    // For every earlier position in `last_tokens` that ends a repeat of the tail of
    // the context, the token that followed that repeat gets its logit reduced by
    //     dry_multiplier * dry_base ^ (match_length - dry_allowed_length)
    // so that extending an already-seen sequence becomes progressively less likely.
    // Matches never extend across tokens listed in `seq_breakers`.

    // nothing to do without history or with a zero multiplier
    // (an empty history is a valid no-op, not a programmer error)
    if (last_tokens_size == 0 || dry_multiplier == 0.0f) {
        return;
    }

    // clamp the allowed length to >= 0 and keep all comparisons in size_t;
    // comparing a size_t match length directly against a (possibly negative)
    // int would silently convert the int to a huge unsigned value
    const size_t allowed_length = dry_allowed_length > 0 ? (size_t) dry_allowed_length : 0;

    // the token whose earlier occurrences anchor every match
    const llama_token last_token = last_tokens[last_tokens_size - 1];

    // maps a "next token" (the token that followed a previous occurrence of the
    // context tail) to the length of the longest such match
    std::unordered_map<llama_token, size_t> match_lengths;

    // scan every previous position; the last token itself is excluded
    for (size_t i = 0; i + 1 < last_tokens_size; ++i) {
        // only positions holding the same token as the context tail can anchor a match
        if (last_tokens[i] != last_token) {
            continue;
        }

        // the token that followed this occurrence (i + 1 < last_tokens_size holds here)
        const llama_token next_token = last_tokens[i + 1];

        // try to extend the match backwards; starts at 1 because the anchor
        // token has already matched
        size_t match_length = 1;
        for (;; match_length++) {
            // stop when the match would run off the front of the history
            if (i < match_length) {
                break;
            }

            // token preceding the earlier occurrence, `match_length` steps back
            const llama_token compare_token = last_tokens[i - match_length];

            // corresponding token at the end of the context; the extra -1 is
            // because the match is anchored at the last token itself
            const llama_token head_token = last_tokens[last_tokens_size - 1 - match_length];

            // matches never extend across a sequence breaker
            if (std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size) {
                break;
            }

            // stop at the first mismatch
            if (compare_token != head_token) {
                break;
            }
        }

        // remember the longest match seen for this follow-up token;
        // operator[] value-initializes an absent entry to 0, so a plain max works
        size_t & stored = match_lengths[next_token];
        stored = std::max(stored, match_length);
    }

    if (match_lengths.empty()) {
        return;
    }

    // apply the penalties: one hash lookup per candidate instead of one linear
    // scan over candidates->data per penalized token
    for (size_t i = 0; i < candidates->size; ++i) {
        const auto it = match_lengths.find(candidates->data[i].id);
        if (it == match_lengths.end()) {
            continue;
        }

        const size_t match_length = it->second;

        // only matches longer than the configured allowed length are penalized
        if (match_length > allowed_length) {
            // exponential penalty: the longer the repeated run, the harsher;
            // std::pow on floats avoids a pointless promotion to double
            const float penalty = dry_multiplier * std::pow(dry_base, (float) (match_length - allowed_length));

            candidates->data[i].logit -= penalty;
        }
    }
}
0 commit comments