Skip to content

Commit 3e858d4

Browse files
VJHack, ggerganov
authored and committed
main : option to disable context shift (ggml-org#9484)
* added cli arg to disable context shift * reverted precommit * updated README.md for main * white space * allow disabling context shift in the server * Update common/arg.cpp no-context-shift only works for main example Co-authored-by: Georgi Gerganov <[email protected]> * added server example to --no-context-shift args * removed server changes * white space --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent aacd3aa commit 3e858d4

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

common/arg.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
685685
params.n_keep = value;
686686
}
687687
));
688+
add_opt(llama_arg(
689+
{"--no-context-shift"},
690+
format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
691+
[](gpt_params & params) {
692+
params.ctx_shift = false;
693+
}
694+
).set_examples({LLAMA_EXAMPLE_MAIN}));
688695
add_opt(llama_arg(
689696
{"--chunks"}, "N",
690697
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
19851992

19861993
return ctx_arg;
19871994
}
1988-

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ struct gpt_params {
246246
bool cont_batching = true; // insert new sequences for decoding on-the-fly
247247
bool flash_attn = false; // flash attention
248248
bool no_perf = false; // disable performance metrics
249+
bool ctx_shift = true; // context shift on infinite text generation
249250

250251
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
251252
bool logits_all = false; // return logits for all tokens in the batch

examples/main/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite
161161

162162
If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
163163

164+
The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
165+
164166
It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
165167

166168
### Temperature

examples/main/main.cpp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
559559
// if we run out of context:
560560
// - take the n_keep first tokens from the original prompt (via n_past)
561561
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
562+
562563
if (n_past + (int) embd.size() >= n_ctx) {
563-
if (params.n_predict == -2) {
564-
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
564+
if (!params.ctx_shift){
565+
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
565566
break;
566-
}
567+
} else {
568+
if (params.n_predict == -2) {
569+
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
570+
break;
571+
}
567572

568-
const int n_left = n_past - params.n_keep;
569-
const int n_discard = n_left/2;
573+
const int n_left = n_past - params.n_keep;
574+
const int n_discard = n_left/2;
570575

571-
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
572-
n_past, n_left, n_ctx, params.n_keep, n_discard);
576+
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
577+
n_past, n_left, n_ctx, params.n_keep, n_discard);
573578

574-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
575-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
579+
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
580+
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
576581

577-
n_past -= n_discard;
582+
n_past -= n_discard;
578583

579-
LOG_DBG("after swap: n_past = %d\n", n_past);
584+
LOG_DBG("after swap: n_past = %d\n", n_past);
580585

581-
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
586+
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
582587

583-
LOG_DBG("clear session path\n");
584-
path_session.clear();
588+
LOG_DBG("clear session path\n");
589+
path_session.clear();
590+
}
585591
}
586592
} else {
587593
// context extension via Self-Extend

0 commit comments

Comments
 (0)