@@ -801,7 +801,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
 
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
-    return user.empty();  // Should have data in happy path
+    if (std::cin.eof()) {
+        printf("\n");
+        return 1;
+    }
+
+    if (user == "/bye") {
+        return 1;
+    }
+
+    if (user.empty()) {
+        return 2;
+    }
+
+    return 0;  // Should have data in happy path
 }
 
 // Function to generate a response based on the prompt
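
The new `read_user_input` replaces the old `user.empty()` result with a three-way return code: `0` means a line was read, `1` means the caller should leave the chat (EOF from Ctrl+D, or the `/bye` command), and `2` means the line was empty and should be re-prompted. A minimal standalone sketch of that contract follows; the `main` driver is illustrative only, not part of run.cpp:

```cpp
#include <cstdio>
#include <iostream>
#include <string>

// Same contract as the diff: 0 = got input, 1 = exit, 2 = empty line, retry.
static int read_user_input(std::string & user) {
    std::getline(std::cin, user);
    if (std::cin.eof()) {
        printf("\n");
        return 1;  // Ctrl+D or end of piped input: exit cleanly
    }
    if (user == "/bye") {
        return 1;  // explicit quit command
    }
    if (user.empty()) {
        return 2;  // nothing typed: caller should prompt again
    }
    return 0;      // happy path: user provided data
}

int main() {
    std::string line;
    while (true) {
        printf("> ");
        const int ret = read_user_input(line);
        if (ret == 1) { break; }     // exit
        if (ret == 2) { continue; }  // re-prompt on empty input
        printf("you said: %s\n", line.c_str());
    }
}
```
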
@@ -868,26 +881,45 @@ static bool is_stdout_a_terminal() {
 #endif
 }
 
-// Function to tokenize the prompt
+// Function to handle user input
+static int get_user_input(std::string & user_input, const std::string & user) {
+    while (true) {
+        const int ret = handle_user_input(user_input, user);
+        if (ret == 1) {
+            return 1;
+        }
+
+        if (ret == 2) {
+            continue;
+        }
+
+        break;
+    }
+
+    return 0;
+}
+
+// Main chat loop function
 static int chat_loop(LlamaData & llama_data, const std::string & user) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
-        // Get user input
         std::string user_input;
-        while (handle_user_input(user_input, user)) {
+        if (get_user_input(user_input, user) == 1) {
+            return 0;
         }
 
         add_message("user", user.empty() ? user_input : user, llama_data);
+
         int new_len;
-        if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
+        if (apply_chat_template_with_error_handling(llama_data, true, new_len) == 1) {
             return 1;
         }
 
         std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
         std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        if (generate_response(llama_data, prompt, response, stdout_a_terminal) == 1) {
             return 1;
         }
 
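
`get_user_input` wraps `handle_user_input` in a retry loop so that `chat_loop` only ever sees "got input" (`0`) or "exit" (`1`); empty-line retries (`2`) are absorbed inside the helper, which is why the old `while (handle_user_input(...))` in `chat_loop` collapses to a single `if`. `handle_user_input` itself is defined elsewhere in run.cpp and not shown in this diff; a plausible sketch, assuming it forwards to `read_user_input` unless a one-shot prompt was passed on the command line (the body here is an assumption, not the actual definition):

```cpp
// Hypothetical sketch of the helper the diff calls; the real definition
// lives elsewhere in run.cpp and may differ.
static int handle_user_input(std::string & user_input, const std::string & user) {
    if (!user.empty()) {
        user_input = user;
        return 0;  // non-interactive: prompt was passed on the command line
    }
    printf("\033[32m> \033[0m");         // interactive prompt
    return read_user_input(user_input);  // 0 = data, 1 = exit, 2 = empty
}
```
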
@@ -896,7 +928,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
         }
 
         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
+        if (apply_chat_template_with_error_handling(llama_data, false, prev_len) == 1) {
             return 1;
         }
     }
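
The checks on `apply_chat_template_with_error_handling` and `generate_response` change from `< 0` and plain truthiness to `== 1`, matching the unified 0-success/1-failure convention the new helpers use. A sketch of what such a wrapper might look like, assuming an underlying `apply_chat_template` that returns the formatted length or a negative value on error (the names come from the diff, but this body is illustrative only):

```cpp
// Illustrative sketch only; the real wrapper lives in run.cpp.
static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
    const int new_len = apply_chat_template(llama_data, append);  // assumed helper
    if (new_len < 0) {
        fprintf(stderr, "failed to apply the chat template\n");
        return 1;  // 0 = success, 1 = failure, like the other helpers
    }
    output_length = new_len;
    return 0;
}
```
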