Skip to content

Commit 0b0ca76

Browse files
committed
Enhance user input handling for llama-run
The main motivation for this change is that Ctrl-D was not being handled correctly. Modify `read_user_input` to handle EOF, the "/bye" command, and empty-input cases. Introduce a `get_user_input` function to manage the user-input loop and handle the different return cases. Signed-off-by: Eric Curtin <[email protected]>
1 parent 99a3755 commit 0b0ca76

File tree

1 file changed

+39
-7
lines changed

1 file changed

+39
-7
lines changed

examples/run/run.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
801801

802802
// Read one line of user input from stdin and classify it.
// Return codes (consumed by get_user_input):
//   0 - `user` holds a non-empty line to process (happy path)
//   1 - caller should stop reading: EOF (Ctrl-D) or the "/bye" command
//   2 - line was empty; caller should prompt again
static int read_user_input(std::string & user) {
    std::getline(std::cin, user);
    if (std::cin.eof()) {
        printf("\n"); // Ctrl-D leaves the cursor mid-line; emit a newline before exiting
        return 1;
    }

    if (user == "/bye") {
        return 1;
    }

    if (user.empty()) {
        return 2;
    }

    return 0; // Should have data in happy path
}
806819

807820
// Function to generate a response based on the prompt
@@ -868,26 +881,45 @@ static bool is_stdout_a_terminal() {
868881
#endif
869882
}
870883

871-
// Function to tokenize the prompt
884+
// Function to handle user input
885+
static int get_user_input(std::string & user_input, const std::string & user) {
886+
while (true) {
887+
const int ret = handle_user_input(user_input, user);
888+
if (ret == 1) {
889+
return 1;
890+
}
891+
892+
if (ret == 2) {
893+
continue;
894+
}
895+
896+
break;
897+
}
898+
899+
return 0;
900+
}
901+
902+
// Main chat loop function
872903
static int chat_loop(LlamaData & llama_data, const std::string & user) {
873904
int prev_len = 0;
874905
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
875906
static const bool stdout_a_terminal = is_stdout_a_terminal();
876907
while (true) {
877-
// Get user input
878908
std::string user_input;
879-
while (handle_user_input(user_input, user)) {
909+
if (get_user_input(user_input, user) == 1) {
910+
return 0;
880911
}
881912

882913
add_message("user", user.empty() ? user_input : user, llama_data);
914+
883915
int new_len;
884-
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
916+
if (apply_chat_template_with_error_handling(llama_data, true, new_len) == 1) {
885917
return 1;
886918
}
887919

888920
std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
889921
std::string response;
890-
if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
922+
if (generate_response(llama_data, prompt, response, stdout_a_terminal) == 1) {
891923
return 1;
892924
}
893925

@@ -896,7 +928,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
896928
}
897929

898930
add_message("assistant", response, llama_data);
899-
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
931+
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) == 1) {
900932
return 1;
901933
}
902934
}

0 commit comments

Comments
 (0)