Skip to content

Commit 09f2184

Browse files
author
John
committed
bugfix for stopword/eos condition after context switch
1 parent bddb0cc commit 09f2184

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

examples/falcon/falcon_main.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,9 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
582582
bool input_echo = true;
583583
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
584584

585-
int n_past = 0;
585+
int n_past = 0; // n_past tells eval() which position in KV we are at
586586
int n_past_system = 0; // not in use
587+
int n_past_total = 0; // n_past_total does not reset on context switches
587588
int n_remain = params.n_predict;
588589
int n_consumed = 0;
589590
int n_session_consumed = 0;
@@ -703,6 +704,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
703704

704705
n_past++;
705706
n_session_consumed++;
707+
n_past_total++;
706708

707709
if (n_session_consumed >= (int) session_tokens.size()) {
708710
++i;
@@ -763,6 +765,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
763765
return 1;
764766
}
765767
n_past += n_eval;
768+
n_past_total += n_eval;
766769
}
767770
if (embd.size() > 0 && !path_session.empty()) {
768771
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
@@ -946,7 +949,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
946949

947950
bool stopword_fulfilled = false;
948951
// stopwords
949-
if (!embd.empty() && n_past > embd_inp.size())
952+
if (!embd.empty() && n_past_total > embd_inp.size())
950953
{
951954
for (const auto& stopword : stopwords)
952955
{
@@ -1122,7 +1125,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
11221125
#endif
11231126

11241127
// end of text token or stopword detected in generated content
1125-
if ((!embd.empty() && embd.back() == falcon_token_eos() && n_past > embd_inp.size()) || stopword_fulfilled)
1128+
if ((!embd.empty() && embd.back() == falcon_token_eos() && n_past_total > embd_inp.size()) || stopword_fulfilled)
11261129
{
11271130
if (params.instruct)
11281131
{
@@ -1132,7 +1135,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
11321135
if (params.verbose_prompt)
11331136
fprintf(stderr, " [end of text]\n");
11341137
// if we are in the prompt ingestion we will not stop
1135-
if (n_past > (int)embd_inp.size()) {
1138+
if (n_past_total > (int)embd_inp.size()) {
11361139
break;
11371140
}
11381141
}

libfalcon.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2172,7 +2172,6 @@ static bool falcon_eval_internal(
21722172
ggml_set_name(Vcur, "Vcur");
21732173

21742174
// using mode = 2 for neox mode
2175-
// Qcur->meta.f_custom[GGML_CUSTOM_F_ROPE_ANG_SCALE] = 0.25; Kcur->meta.f_custom[GGML_CUSTOM_F_ROPE_ANG_FREQ] = 0.25;
21762175
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, head_dim, 2,n_ctx);
21772176
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, head_dim, 2,n_ctx);
21782177

0 commit comments

Comments
 (0)