@@ -582,8 +582,9 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
582
582
bool input_echo = true ;
583
583
bool need_to_save_session = !path_session.empty () && n_matching_session_tokens < embd_inp.size ();
584
584
585
- int n_past = 0 ;
585
+ int n_past = 0 ; // n_past tells eval() which position in KV we are at
586
586
int n_past_system = 0 ; // not in use
587
+ int n_past_total = 0 ; // n_past_total does not reset on context switches
587
588
int n_remain = params.n_predict ;
588
589
int n_consumed = 0 ;
589
590
int n_session_consumed = 0 ;
@@ -703,6 +704,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
703
704
704
705
n_past++;
705
706
n_session_consumed++;
707
+ n_past_total++;
706
708
707
709
if (n_session_consumed >= (int ) session_tokens.size ()) {
708
710
++i;
@@ -763,6 +765,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
763
765
return 1 ;
764
766
}
765
767
n_past += n_eval;
768
+ n_past_total += n_eval;
766
769
}
767
770
if (embd.size () > 0 && !path_session.empty ()) {
768
771
session_tokens.insert (session_tokens.end (), embd.begin (), embd.end ());
@@ -946,7 +949,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
946
949
947
950
bool stopword_fulfilled = false ;
948
951
// stopwords: only checked when embd is non-empty and the processed-token count exceeds the prompt length
949
- if (!embd.empty () && n_past > embd_inp.size ())
952
+ if (!embd.empty () && n_past_total > embd_inp.size ())
950
953
{
951
954
for (const auto & stopword : stopwords)
952
955
{
@@ -1122,7 +1125,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
1122
1125
#endif
1123
1126
1124
1127
// end of text token or stopword detected in generated content
1125
- if ((!embd.empty () && embd.back () == falcon_token_eos () && n_past > embd_inp.size ()) || stopword_fulfilled)
1128
+ if ((!embd.empty () && embd.back () == falcon_token_eos () && n_past_total > embd_inp.size ()) || stopword_fulfilled)
1126
1129
{
1127
1130
if (params.instruct )
1128
1131
{
@@ -1132,7 +1135,7 @@ fprintf(stderr, "+------------+-------+-------+-------+-------+---------------+-
1132
1135
if (params.verbose_prompt )
1133
1136
fprintf (stderr, " [end of text]\n " );
1134
1137
// do not stop while the prompt is still being ingested
1135
- if (n_past > (int )embd_inp.size ()) {
1138
+ if (n_past_total > (int )embd_inp.size ()) {
1136
1139
break ;
1137
1140
}
1138
1141
}
0 commit comments