@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
232
232
}
233
233
}
234
234
235
- void llama_sample_dry_impl (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
236
- // skip dry sampler if we don't have a previous token
237
- if (last_tokens_size < 1 ) return ;
235
+ std::vector<llama_token> llama_tokenize (
236
+ const struct llama_context * ctx,
237
+ const std::string & text,
238
+ bool add_special,
239
+ bool parse_special) {
240
+ return llama_tokenize (llama_get_model (ctx), text, add_special, parse_special);
241
+ }
242
+
243
+ std::vector<llama_token> llama_tokenize (
244
+ const struct llama_model * model,
245
+ const std::string & text,
246
+ bool add_special,
247
+ bool parse_special) {
248
+ // upper limit for the number of tokens
249
+ int n_tokens = text.length () + 2 * add_special;
250
+ std::vector<llama_token> result (n_tokens);
251
+ n_tokens = llama_tokenize (model, text.data (), text.length (), result.data (), result.size (), add_special, parse_special);
252
+ if (n_tokens < 0 ) {
253
+ result.resize (-n_tokens);
254
+ int check = llama_tokenize (model, text.data (), text.length (), result.data (), result.size (), add_special, parse_special);
255
+ GGML_ASSERT (check == -n_tokens);
256
+ } else {
257
+ result.resize (n_tokens);
258
+ }
259
+ return result;
260
+ }
261
+
262
+ std::string llama_detokenize (llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
263
+ std::string text;
264
+ text.resize (std::max (text.capacity (), tokens.size ()));
265
+ int32_t n_chars = llama_detokenize (llama_get_model (ctx), tokens.data (), (int32_t )tokens.size (), &text[0 ], (int32_t )text.size (), false , special);
266
+ if (n_chars < 0 ) {
267
+ text.resize (-n_chars);
268
+ n_chars = llama_detokenize (llama_get_model (ctx), tokens.data (), (int32_t )tokens.size (), &text[0 ], (int32_t )text.size (), false , special);
269
+ GGML_ASSERT (n_chars <= (int32_t )text.size ()); // whitespace trimming is performed after per-token detokenization
270
+ }
271
+
272
+ text.resize (n_chars);
273
+
274
+ // NOTE: the original tokenizer decodes bytes after collecting the pieces.
275
+ return text;
276
+ }
277
+
278
+ std::string llama_detokenize_single (llama_context * ctx, llama_token token, bool special) {
279
+ std::vector<llama_token> tokens = {token};
280
+ return llama_detokenize (ctx, tokens, special);
281
+ }
238
282
239
- // get the last token
240
- auto last_token = last_tokens[last_tokens_size - 1 ];
283
// Constants for preventing overflow in the DRY sampler.
// `constexpr` (rather than `const`) makes the compile-time nature explicit.
//
// Largest x for which expf(x) is still finite: log(FLT_MAX) ~= 88.72284.
constexpr float FLOAT_MAX_LOG = 88.7228391f;
// Per-breaker cap on the number of characters considered when tokenizing a sequence breaker.
constexpr int MAX_CHAR_LEN = 40;
// Per-breaker cap on the number of tail tokens kept for breaker matching.
constexpr int MAX_SEQ_LEN = 20;
241
287
242
- // if last token is part of the sequence breakers, skip whole sampler
243
- if (std::find (dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
288
+
289
+ void llama_sample_dry_impl (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
290
+ if (last_tokens_size < 1 ) {
244
291
return ;
245
292
}
246
293
247
- // create an unordered map of "next tokens" <-> max match length
294
+ // Cache for token-to-string conversions
295
+ std::unordered_map<llama_token, std::string> token_to_string_cache;
296
+ // Store sequence breakers for more efficient lookup
297
+ std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
298
+
299
+ auto detokenize_with_cache = [&](llama_token token) -> std::string {
300
+ auto it = token_to_string_cache.find (token);
301
+ if (it != token_to_string_cache.end ()) {
302
+ return it->second ;
303
+ }
304
+ std::string token_str = llama_detokenize_single (ctx, token, false );
305
+ token_to_string_cache[token] = token_str;
306
+ return token_str;
307
+ };
308
+
309
+ // Pre-process dry_seq_breakers
310
+ for (const auto & breaker : dry_seq_breakers) {
311
+ std::string breaker_trimmed = breaker.substr (0 , MAX_CHAR_LEN);
312
+ std::vector<llama_token> tokens = llama_tokenize (ctx, breaker_trimmed, false , false );
313
+
314
+ if (!tokens.empty ()) {
315
+ std::string head = detokenize_with_cache (tokens[0 ]);
316
+ std::vector<std::string> tail;
317
+
318
+ for (size_t i = 1 ; i < tokens.size () && i <= MAX_SEQ_LEN; ++i) {
319
+ tail.push_back (detokenize_with_cache (tokens[i]));
320
+ }
321
+ restart_sequences.emplace (head, tail);
322
+ }
323
+ }
324
+
325
+ // Find max repetition length considering restart sequences
326
+ int rep_limit = last_tokens_size;
327
+
328
+ for (size_t i = 0 ; i < last_tokens_size; ++i) {
329
+ size_t ix = last_tokens_size - 1 - i;
330
+ std::string token_str = detokenize_with_cache (last_tokens[ix]);
331
+
332
+ // Check if the token is a potential sequence breaker
333
+ auto its = restart_sequences.equal_range (token_str);
334
+ if (its.first == restart_sequences.end ()) continue ;
335
+
336
+ int longest_match = -1 ;
337
+ // Check all potential sequence breakers starting with this token
338
+ for (auto it = its.first ; it != its.second ; ++it) {
339
+ int seq_len = (int )it->second .size ();
340
+ if (seq_len > longest_match && seq_len <= i) {
341
+ bool match = true ;
342
+ // Check if the following tokens match the sequence breaker
343
+ for (size_t offset = 0 ; offset < seq_len; ++offset) {
344
+ if (it->second [offset] != detokenize_with_cache (last_tokens[ix + 1 + offset])) {
345
+ match = false ;
346
+ break ;
347
+ }
348
+ }
349
+ if (match) {
350
+ longest_match = seq_len;
351
+ }
352
+ }
353
+ }
354
+
355
+ if (longest_match >= 0 ) {
356
+ rep_limit = static_cast <int >(i) - longest_match;
357
+ break ;
358
+ }
359
+ }
360
+
361
+ if (rep_limit <= dry_allowed_length) {
362
+ return ;
363
+ }
364
+
365
+ // Store max match length for each token
248
366
std::unordered_map<llama_token, size_t > match_lengths;
249
367
250
- // loop through each previous token (exclude the last token)
368
+ // Find repeated sequences
251
369
for (size_t i = 0 ; i < last_tokens_size - 1 ; ++i) {
252
- // skip if the compare token is not the same as the last token
253
- if (last_tokens[i] != last_token) {
370
+ if (last_tokens[i] != last_tokens[last_tokens_size - 1 ]) {
254
371
continue ;
255
372
}
256
373
257
- // get the next token (i + 1 is always less than last_tokens_size)
258
374
auto next_token = last_tokens[i + 1 ];
375
+ std::string next_token_str = detokenize_with_cache (next_token);
259
376
260
- // if next token is part of the sequence breakers, skip
261
- if (std::find (dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
377
+ // Skip if next token is a sequence breaker
378
+ auto its = restart_sequences.equal_range (next_token_str);
379
+ if (its.first != restart_sequences.end ()) {
262
380
continue ;
263
381
}
264
382
265
- // try to extend the match backwards (match length starts at 1 because last token is already matched)
266
383
size_t match_length = 1 ;
267
384
268
- // loop through the previous tokens
385
+ // Extend match as far as possible
269
386
for (;; match_length++) {
270
- // if we have reached the start of our last tokens, break
271
- if (i < match_length) break ;
387
+ if (i < match_length || match_length > rep_limit) {
388
+ break ;
389
+ }
272
390
273
- // compare token starts at our prev index, going backwards by match length
274
391
auto compare_token = last_tokens[i - match_length];
392
+ std::string compare_token_str = detokenize_with_cache (compare_token);
275
393
276
- // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
277
394
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
395
+ std::string head_token_str = detokenize_with_cache (head_token);
278
396
279
- // break out of the match if any tokens don't match
280
- if (compare_token != head_token) {
397
+ if (compare_token_str != head_token_str) {
281
398
break ;
282
399
}
283
400
284
- // if compare token is part of the sequence breakers, break out of the match
285
- if (std::find (dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
401
+ // Check if we've hit a sequence breaker
402
+ its = restart_sequences.equal_range (compare_token_str);
403
+ if (its.first != restart_sequences.end ()) {
286
404
break ;
287
405
}
288
406
}
289
407
290
- // Check if the next token exists in the map
408
+ // Update max match length for this token
291
409
auto it = match_lengths.find (next_token);
292
-
293
410
if (it == match_lengths.end ()) {
294
- // Key does not exist, insert the new value
295
411
match_lengths[next_token] = match_length;
296
412
} else {
297
- // Key exists, update it with the max of the new value or the existing value
298
413
it->second = std::max (it->second , match_length);
299
414
}
300
415
}
301
416
302
- // apply penalties
417
+ // Calculate max safe exponent
418
+ int max_exponent = 0 ;
419
+ if (dry_base > 1 .000001f ) {
420
+ max_exponent = static_cast <int >(FLOAT_MAX_LOG / log (dry_base));
421
+ }
422
+
423
+ #ifdef DEBUG
424
+ LLAMA_LOG_INFO (" DRY Sampling parameters:\n " );
425
+ LLAMA_LOG_INFO (" dry_base: %f\n " , dry_base);
426
+ LLAMA_LOG_INFO (" dry_multiplier: %f\n " , dry_multiplier);
427
+ LLAMA_LOG_INFO (" dry_allowed_length: %d\n " , dry_allowed_length);
428
+ LLAMA_LOG_INFO (" max_exponent: %d\n " , max_exponent);
429
+ LLAMA_LOG_INFO (" DRY penalties [" );
430
+ #endif
431
+
432
+ // Apply penalties
303
433
for (const auto & pair : match_lengths) {
304
434
auto next_token = pair.first ;
305
435
auto match_length = pair.second ;
306
436
307
- // if the match length is greater than or equal to our allowed length in config, we apply penalities
308
- if (match_length >= (size_t )dry_allowed_length) {
309
-
310
- // find our next token in the candidates->data
437
+ if (match_length >= static_cast <size_t >(dry_allowed_length)) {
311
438
for (size_t i = 0 ; i < candidates->size ; ++i) {
312
439
if (candidates->data [i].id == next_token) {
313
- // calculate the penalty
314
- float penalty = dry_multiplier * pow (dry_base, match_length - dry_allowed_length);
315
-
316
- // apply the dry penalty
440
+ int repeat_exp = static_cast <int >(match_length - dry_allowed_length);
441
+ if (max_exponent > 0 && repeat_exp > max_exponent) {
442
+ repeat_exp = max_exponent;
443
+ }
444
+ float penalty = dry_multiplier * pow (dry_base, static_cast <float >(repeat_exp));
317
445
candidates->data [i].logit -= penalty;
446
+
447
+ #ifdef DEBUG
448
+ LLAMA_LOG_INFO (" Token %d: %s (Penalty: %.2f)" , next_token, detokenize_with_cache (next_token).c_str (), penalty);
449
+ #endif
318
450
break ;
319
451
}
320
452
}
321
453
}
322
454
}
455
+
456
+ #ifdef DEBUG
457
+ LLAMA_LOG_INFO (" ]\n " );
458
+ #endif
323
459
}
324
460
325
461
void llama_sample_tail_free_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
0 commit comments