@@ -176,8 +176,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
        }
    }

-   const ggml_type wtype2 = GGML_TYPE_F32;
-
    auto & ctx = model.ctx;

    size_t ctx_size = 0;
@@ -237,7 +235,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
        const int n_vocab = hparams.n_vocab;

        model.layers.resize(n_layer);
@@ -539,9 +536,7 @@ bool llama_eval(
    const int n_vocab = hparams.n_vocab;
    const int n_rot   = hparams.n_embd/hparams.n_head;

-    const int d_key = n_embd/n_head;
-
-    // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
+    // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
    // static size_t buf_size = hparams.n_ctx*1024*1024;
    static size_t buf_size = 512u*1024*1024;
    static void * buf = malloc(buf_size);
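
Context for the TODO above: llama_eval grabs one static scratch buffer sized at a flat 512 MiB rather than deriving it from n_ctx. A minimal sketch of that allocate-once pattern; the size constant and function name here are illustrative only:

#include <cstdlib>

// Allocate the evaluation scratch buffer once and reuse it on every call;
// it is intentionally never freed before process exit.
static void * eval_scratch(size_t & size_out) {
    static size_t buf_size = 512u*1024*1024;  // flat 512 MiB, not scaled by n_ctx
    static void * buf      = malloc(buf_size);
    size_out = buf_size;
    return buf;
}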
@@ -792,7 +787,7 @@ int main(int argc, char ** argv) {
    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }
-
+
    if (params.n_ctx > 2048) {
        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
                " expect poor results\n", __func__, params.n_ctx);
@@ -820,7 +815,7 @@ int main(int argc, char ** argv) {
    // load the model
    {
        const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }
@@ -849,9 +844,25 @@ int main(int argc, char ** argv) {

    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

+    // prefix & suffix for instruct mode
+    const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
+    const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
+
+    // in instruct mode, we inject a prefix and a suffix to each input by the user
+    if (params.instruct) {
+        fprintf(stderr, "== Instruction mode enabled ==\n");
+        params.interactive = true;
+        params.antiprompt = "### Instruction:\n\n";
+    }
+
    // tokenize the reverse prompt
    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);

+    // enable interactive mode if reverse prompt is specified
+    if (!antiprompt_inp.empty()) {
+        params.interactive = true;
+    }
+
    fprintf(stderr, "\n");
    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
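
This hunk is the heart of instruct mode: every user turn gets wrapped in the Alpaca-style "### Instruction:" / "### Response:" template, and the instruction header doubles as the reverse prompt so generation pauses when the model starts a new turn. A self-contained sketch of that turn assembly; the tokenize helper below is a hypothetical stand-in for ::llama_tokenize (one id per character, for illustration only):

#include <cstdio>
#include <string>
#include <vector>

// hypothetical stand-in for ::llama_tokenize
static std::vector<int> tokenize(const std::string & text, bool /*add_bos*/) {
    return std::vector<int>(text.begin(), text.end());
}

// Lay out one turn as: "\n\n### Instruction:\n\n" + input + "\n\n### Response:\n\n",
// matching the template the fine-tuned model expects.
static std::vector<int> build_instruct_turn(const std::string & user_input) {
    std::vector<int> turn = tokenize("\n\n### Instruction:\n\n", false);
    const std::vector<int> body = tokenize(user_input, false);
    const std::vector<int> sfx  = tokenize("\n\n### Response:\n\n", false);
    turn.insert(turn.end(), body.begin(), body.end());
    turn.insert(turn.end(), sfx.begin(), sfx.end());
    return turn;
}

int main() {
    printf("turn length: %zu tokens\n", build_instruct_turn("hello").size());
}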
@@ -872,7 +883,7 @@ int main(int argc, char ** argv) {

        fprintf(stderr, "%s: interactive mode on.\n", __func__);

-        if (antiprompt_inp.size()) {
+        if (antiprompt_inp.size()) {
            fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
            fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
            for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
@@ -894,31 +905,27 @@ int main(int argc, char ** argv) {
    std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

-
    if (params.interactive) {
        fprintf(stderr, "== Running in interactive mode. ==\n"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
               " - Press Ctrl+C to interject at any time.\n"
#endif
               " - Press Return to return control to LLaMa.\n"
-               " - If you want to submit another line, end your input in '\\'.\n");
+               " - If you want to submit another line, end your input in '\\'.\n\n");
+        is_interacting = true;
    }

-    int remaining_tokens = params.n_predict;
    int input_consumed = 0;
    bool input_noecho = false;

-    // prompt user immediately after the starting prompt has been loaded
-    if (params.interactive_start) {
-        is_interacting = true;
-    }
+    int remaining_tokens = params.n_predict;

    // set the color for the prompt which will be output initially
    if (params.use_color) {
        printf(ANSI_COLOR_YELLOW);
    }

-    while (remaining_tokens > 0) {
+    while (remaining_tokens > 0 || params.interactive) {
        // predict
        if (embd.size() > 0) {
            const int64_t t_start_us = ggml_time_us();
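
The important change here is the loop condition: with plain remaining_tokens > 0 the program exits as soon as the prediction budget is spent, while the added || params.interactive keeps the session alive so control can bounce back to the user. A toy, compilable illustration of the difference (all names are illustrative, not the real llama.cpp variables):

#include <cstdio>

int main() {
    int  remaining_tokens = 3;
    bool interactive      = true;
    int  turns            = 0;

    while (remaining_tokens > 0 || interactive) {
        if (remaining_tokens == 0) {
            // in the real code this is where the user is prompted again;
            // here we refill the budget once and then end the session
            if (++turns > 1) break;
            remaining_tokens = 3;
        }
        --remaining_tokens;
        printf("generated token (remaining: %d)\n", remaining_tokens);
    }
}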
@@ -971,13 +978,13 @@ int main(int argc, char ** argv) {
                last_n_tokens.erase(last_n_tokens.begin());
                last_n_tokens.push_back(embd_inp[input_consumed]);
                ++input_consumed;
-                if (embd.size() > params.n_batch) {
+                if ((int) embd.size() > params.n_batch) {
                    break;
                }
            }

            // reset color to default if we there is no pending user input
-            if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
+            if (!input_noecho && params.use_color && (int) embd_inp.size() == input_consumed) {
                printf(ANSI_COLOR_RESET);
            }
        }
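
The two (int) casts are not cosmetic: embd.size() returns an unsigned size_t, and comparing it against a signed int promotes the int to unsigned, which goes wrong for negative values (and trips -Wsign-compare). A small demonstration; the cast on n below just makes the implicit conversion visible:

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> v = {1, 2, 3};
    int n = -1;
    // without a cast, n converts to a huge unsigned value, so this is false:
    printf("uncast: %s\n", v.size() > (size_t) n ? "true" : "false");
    // casting the size to int keeps the comparison in signed arithmetic:
    printf("cast:   %s\n", (int) v.size() > n    ? "true" : "false");
}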
@@ -999,19 +1006,26 @@ int main(int argc, char ** argv) {
                is_interacting = true;
            }
            if (is_interacting) {
+                if (params.instruct) {
+                    input_consumed = embd_inp.size();
+                    embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+
+                    printf("\n> ");
+                }
+
                // currently being interactive
-                bool another_line= true;
+                bool another_line = true;
                while (another_line) {
                    fflush(stdout);
                    char buf[256] = {0};
                    int n_read;
-                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
                        // presumable empty line, consume the newline
                        std::ignore = scanf("%*c");
                        n_read=0;
                    }
-                    if (params.use_color) printf(ANSI_COLOR_RESET);
+                    if (params.use_color) printf(ANSI_COLOR_RESET);

                    if (n_read > 0 && buf[n_read-1]=='\\') {
                        another_line = true;
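
A note on the line-reading idiom kept here: in scanf("%255[^\n]%n%*c", ...), the scanset %255[^\n] collects up to 255 non-newline characters, %n records how many were consumed, and %*c discards the trailing newline. On an empty line the scanset matches nothing and scanf returns 0 with the newline still pending, which is why the fallback consumes one character separately. A standalone version of the pattern:

#include <cstdio>
#include <tuple>

int main() {
    char buf[256] = {0};
    int  n_read   = 0;
    // read one line: up to 255 chars, count them, drop the newline
    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
        // empty line: nothing matched, so eat the pending newline here
        std::ignore = scanf("%*c");
        n_read = 0;
    }
    printf("read %d chars: '%s'\n", n_read, buf);
}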
@@ -1026,6 +1040,10 @@ int main(int argc, char ** argv) {
                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

+                    if (params.instruct) {
+                        embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                    }
+
                    remaining_tokens -= line_inp.size();

                    input_noecho = true; // do not echo this again
@@ -1037,8 +1055,12 @@ int main(int argc, char ** argv) {

        // end of text token
        if (embd.back() == 2) {
-            fprintf(stderr, " [end of text]\n");
-            break;
+            if (params.interactive) {
+                is_interacting = true;
+            } else {
+                fprintf(stderr, " [end of text]\n");
+                break;
+            }
        }
    }

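The last hunk changes end-of-text handling: token id 2 is treated as the model's end-of-sequence token (as in the original LLaMA vocabulary), and in interactive mode it now hands control back to the user instead of terminating. A compact sketch of the branch; the helper function is hypothetical and only mirrors the diff's logic:

// returns false when generation should stop entirely
static bool handle_token(int token, bool interactive, bool & is_interacting) {
    const int eos_token_id = 2;     // assumed LLaMA end-of-sequence id
    if (token == eos_token_id) {
        if (interactive) {
            is_interacting = true;  // return control to the user
            return true;            // keep the session running
        }
        return false;               // non-interactive: stop after EOS
    }
    return true;
}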