@@ -855,14 +855,18 @@ int main(int argc, char ** argv) {
855
855
// in instruct mode, we inject a prefix and a suffix to each input by the user
856
856
if (params.instruct ) {
857
857
params.interactive = true ;
858
- params.antiprompt = " ### Instruction:\n\n " ;
858
+ params.antiprompt . push_back ( " ### Instruction:\n\n " ) ;
859
859
}
860
860
861
861
// tokenize the reverse prompt
862
- std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize (vocab, params.antiprompt , false );
862
+ std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
863
+
864
+ for (const auto & antiprompt : params.antiprompt) { // by reference: each prompt is only read for tokenization
865
+ antipromptv_inp.push_back (::llama_tokenize (vocab, antiprompt, false ));
866
+ }
863
867
864
868
// enable interactive mode if reverse prompt is specified
865
- if (!antiprompt_inp. empty ()) {
869
+ if (!antipromptv_inp.empty()) { // at least one reverse prompt was given
866
870
params.interactive = true ;
867
871
}
868
872
@@ -886,13 +890,16 @@ int main(int argc, char ** argv) {
886
890
887
891
fprintf (stderr, " %s: interactive mode on.\n " , __func__);
888
892
889
- if (antiprompt_inp.size ()) {
890
- fprintf (stderr, " %s: reverse prompt: '%s'\n " , __func__, params.antiprompt .c_str ());
891
- fprintf (stderr, " %s: number of tokens in reverse prompt = %zu\n " , __func__, antiprompt_inp.size ());
892
- for (int i = 0 ; i < (int ) antiprompt_inp.size (); i++) {
893
- fprintf (stderr, " %6d -> '%s'\n " , antiprompt_inp[i], vocab.id_to_token .at (antiprompt_inp[i]).c_str ());
893
+ if (antipromptv_inp.size ()) {
894
+ for (size_t apindex = 0 ; apindex < antipromptv_inp.size (); ++apindex) {
895
+ const auto & antiprompt_inp = antipromptv_inp.at(apindex); // read-only view, no copy needed for printing
896
+ fprintf (stderr, " %s: reverse prompt: '%s'\n " , __func__, params.antiprompt .at (apindex).c_str ());
897
+ fprintf (stderr, " %s: number of tokens in reverse prompt = %zu\n " , __func__, antiprompt_inp.size ());
898
+ for (int i = 0 ; i < (int ) antiprompt_inp.size (); i++) {
899
+ fprintf (stderr, " %6d -> '%s'\n " , antiprompt_inp[i], vocab.id_to_token .at (antiprompt_inp[i]).c_str ());
900
+ }
901
+ fprintf (stderr, " \n " );
894
902
}
895
- fprintf (stderr, " \n " );
896
903
}
897
904
}
898
905
fprintf (stderr, " sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
@@ -1009,9 +1016,12 @@ int main(int argc, char ** argv) {
1009
1016
// check if we should prompt the user for more
1010
1017
if (params.interactive && embd_inp.size () <= input_consumed) {
1011
1018
// check for reverse prompt
1012
- if (antiprompt_inp.size () && std::equal (antiprompt_inp.rbegin (), antiprompt_inp.rend (), last_n_tokens.rbegin ())) {
1013
- // reverse prompt found
1014
- is_interacting = true ;
1019
+ for (const auto & antiprompt_inp : antipromptv_inp) { // by reference: avoid copying a token vector on every check
1020
+ if (antiprompt_inp.size () && std::equal (antiprompt_inp.rbegin (), antiprompt_inp.rend (), last_n_tokens.rbegin ())) {
1021
+ // reverse prompt found
1022
+ is_interacting = true ;
1023
+ break ;
1024
+ }
1015
1025
}
1016
1026
if (is_interacting) {
1017
1027
if (params.instruct ) {
0 commit comments