
Commit f777a73

Some llama-run cleanups (#11973)
Use the consolidated open function call from the File class. Rename read_all() to to_string(). Remove exclusive locking; the intent of that lock is to prevent multiple processes from writing to the same file, which is not an issue for readers, though we may want to consider adding a shared lock. Stop passing nullptr as a reference, since references are never supposed to be null. clang-format the code for consistent styling. Signed-off-by: Eric Curtin <[email protected]>
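For readers unfamiliar with the locking distinction drawn above, here is a minimal POSIX sketch of exclusive versus shared advisory locks. The flock() usage and the file path are illustrative assumptions for this note, not code from this commit:

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <sys/file.h>  // flock() -- POSIX/BSD advisory locking

int main() {
    FILE * file = std::fopen("example.txt", "r");  // hypothetical path
    if (!file) {
        std::fprintf(stderr, "Error opening file: %s\n", std::strerror(errno));
        return 1;
    }

    // LOCK_EX excludes everyone, readers included -- what the removed exclusive
    // lock did. LOCK_SH lets concurrent readers proceed while still blocking an
    // exclusive (writer) lock -- the "shared lock" the commit message mentions.
    if (flock(fileno(file), LOCK_SH) != 0) {
        std::fprintf(stderr, "Error locking file: %s\n", std::strerror(errno));
    }

    // ... read the file here ...

    flock(fileno(file), LOCK_UN);
    std::fclose(file);
    return 0;
}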
1 parent af7747c commit f777a73

File tree

1 file changed: +45 −46 lines

examples/run/run.cpp

Lines changed: 45 additions & 46 deletions
@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
-
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
+
         return out;
     }
 
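For reference, the reworked helper reads in one piece roughly as follows. This is a condensed, self-contained sketch: the real File class in run.cpp has additional members (locking, the consolidated open() wrapper's error handling), so the class body and the example path below are partly assumptions.

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <string>

// Condensed sketch of the File helper as implied by the diff above.
class File {
  public:
    FILE * file = nullptr;

    FILE * open(const std::string & filename, const char * mode) {
        file = std::fopen(filename.c_str(), mode);
        return file;
    }

    std::string to_string() {
        std::fseek(file, 0, SEEK_END);
        const size_t size = std::ftell(file);
        std::fseek(file, 0, SEEK_SET);
        std::string out;
        out.resize(size);
        const size_t read_size = std::fread(&out[0], 1, size, file);
        if (read_size != size) {
            std::fprintf(stderr, "Error reading file: %s\n", std::strerror(errno));
        }
        return out;
    }

    ~File() {
        if (file) {
            std::fclose(file);
        }
    }
};

int main() {
    File file;
    if (!file.open("chat_template.jinja", "r")) {  // hypothetical path
        std::fprintf(stderr, "Error opening file: %s\n", std::strerror(errno));
        return 1;
    }
    const std::string contents = file.to_string();
    std::printf("%zu bytes read\n", contents.size());
    return 0;
}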
@@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if(chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if(chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
+}
+
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+    int new_len;
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+    std::string response;
+    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
-    std::string chat_template = "";
-    if(!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
 
+    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
            return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
-        int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
-            return 1;
-        }
-
-        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
-        std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
             return 1;
-        }
-
-        if (!user.empty()) {
+        } else if (ret == 2) {
             break;
         }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
     }
 
     return 0;
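The extracted process_user_message() signals back to the loop through a small return-code contract: 0 means keep looping, 1 means error, and 2 means the one-shot message supplied via Opt::user was handled and the loop should end. A minimal sketch of that pattern, with a hypothetical stub standing in for the real handler:

#include <cstdio>
#include <string>

// Illustrative stand-in for process_user_message(): 0 = continue the chat
// loop, 1 = error (propagated to the caller), 2 = one-shot message handled.
static int process_stub(const std::string & input, bool one_shot) {
    if (input.empty()) {
        return 1;  // error
    }
    std::printf("handled: %s\n", input.c_str());
    return one_shot ? 2 : 0;
}

int main() {
    while (true) {
        const int ret = process_stub("hello", /*one_shot=*/true);
        if (ret == 1) {
            return 1;  // error path, mirroring chat_loop above
        } else if (ret == 2) {
            break;     // finished the single user message cleanly
        }
    }
    return 0;
}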
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }
 