Skip to content

Commit a31ce1a

Browse files
committed
added cmd to main for dry sampler
1 parent d8b47da commit a31ce1a

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

common/common.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,30 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
582582
sparams.penalty_present = std::stof(argv[i]);
583583
return true;
584584
}
585+
if (arg == "--dry-multiplier") {
586+
if (++i >= argc) {
587+
invalid_param = true;
588+
return true;
589+
}
590+
sparams.dry_multiplier = std::stof(argv[i]);
591+
return true;
592+
}
593+
if (arg == "--dry-base") {
594+
if (++i >= argc) {
595+
invalid_param = true;
596+
return true;
597+
}
598+
sparams.dry_base = std::stoi(argv[i]);
599+
return true;
600+
}
601+
if (arg == "--dry-allowed-length") {
602+
if (++i >= argc) {
603+
invalid_param = true;
604+
return true;
605+
}
606+
sparams.dry_allowed_length = std::stoi(argv[i]);
607+
return true;
608+
}
585609
if (arg == "--dynatemp-range") {
586610
if (++i >= argc) {
587611
invalid_param = true;
@@ -1425,6 +1449,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
14251449
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
14261450
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
14271451
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
1452+
printf(" --dry-multiplier N DRY sampler multiplier (default: %.1f, 0.0 = disabled)\n", (double)sparams.dry_multiplier);
1453+
printf(" --dry-base N DRY sampler base (default: %.1f)\n", (double)sparams.dry_base);
1454+
printf(" --dry-allowed-length N\n");
1455+
printf(" DRY sampler allowed length (default: %d)\n", sparams.dry_allowed_length);
14281456
printf(" --dynatemp-range N dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range);
14291457
printf(" --dynatemp-exp N dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent);
14301458
printf(" --mirostat N use Mirostat sampling.\n");

common/sampling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
9999
snprintf(result, sizeof(result),
100100
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
101101
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
102-
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
102+
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, dry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d",
103103
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
104104
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
105-
params.mirostat, params.mirostat_eta, params.mirostat_tau);
105+
params.mirostat, params.mirostat_eta, params.mirostat_tau, params.dry_multiplier, params.dry_base, params.dry_allowed_length);
106106

107107
return std::string(result);
108108
}

0 commit comments

Comments
 (0)