Skip to content

Commit 8dae7ce

Browse files
authored
Add --cfg-negative-prompt-file option for examples (#2591)
Add --cfg-negative-prompt-file option for examples
1 parent a73ccf1 commit 8dae7ce

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

examples/common.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
274274
break;
275275
}
276276
params.cfg_negative_prompt = argv[i];
277+
} else if (arg == "--cfg-negative-prompt-file") {
278+
if (++i >= argc) {
279+
invalid_param = true;
280+
break;
281+
}
282+
std::ifstream file(argv[i]);
283+
if (!file) {
284+
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
285+
invalid_param = true;
286+
break;
287+
}
288+
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
289+
if (params.cfg_negative_prompt.back() == '\n') {
290+
params.cfg_negative_prompt.pop_back();
291+
}
277292
} else if (arg == "--cfg-scale") {
278293
if (++i >= argc) {
279294
invalid_param = true;
@@ -567,8 +582,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
567582
fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
568583
fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
569584
fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
570-
fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
585+
fprintf(stdout, " --cfg-negative-prompt PROMPT\n");
571586
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
587+
fprintf(stdout, " --cfg-negative-prompt-file FNAME\n");
588+
fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n");
572589
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
573590
fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
574591
fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);

0 commit comments

Comments
 (0)