@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
-
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
+
         return out;
     }
 
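Note on the hunk above: `to_string()` assumes the file is already open, so opening the file (and reporting a failure to open) moves to the caller. A minimal, self-contained sketch of that shape for reference; `FileSketch`, its `fprintf` error path, and the closing destructor are illustrative assumptions, not the project's actual `File` class:

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <string>

struct FileSketch {
    FILE * file = nullptr;

    // Returns the FILE * (nullptr on failure) so callers can write: if (!f.open(...)) { ... }
    FILE * open(const std::string & filename, const char * mode) {
        file = fopen(filename.c_str(), mode);
        return file;
    }

    // Reads the whole already-open file into a string, mirroring the diff above.
    std::string to_string() {
        fseek(file, 0, SEEK_END);
        const size_t size = ftell(file);
        fseek(file, 0, SEEK_SET);
        std::string out;
        out.resize(size);
        const size_t read_size = fread(&out[0], 1, size, file);
        if (read_size != size) {
            fprintf(stderr, "Error reading file: %s\n", strerror(errno));
        }
        return out;
    }

    // Assumed RAII cleanup; the real class may manage this differently.
    ~FileSketch() {
        if (file) {
            fclose(file);
        }
    }
};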
@@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if (chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if (chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
+}
+
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+    int new_len;
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+    std::string response;
+    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
-    std::string chat_template = "";
-    if (!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
 
+    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
             return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
-        int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
-            return 1;
-        }
-
-        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
-        std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
             return 1;
-        }
-
-        if (!user.empty()) {
+        } else if (ret == 2) {
             break;
         }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
     }
 
     return 0;
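Aside on the `prev_len`/`new_len` pair used above: the whole chat history is re-rendered into `llama_data.fmtted` each turn, and only the newly appended slice `[prev_len, new_len)` is handed to generation. A toy illustration of that incremental-slicing idea, independent of llama.cpp (all names here are placeholders):

#include <cstdio>
#include <string>
#include <vector>

// Toy stand-in: re-render the full history, then take only the unseen tail.
static std::string render(const std::vector<std::string> & msgs) {
    std::string out;
    for (const auto & m : msgs) { out += m + "\n"; }
    return out;
}

int main() {
    std::vector<std::string> history;
    int prev_len = 0;

    history.push_back("user: hi");
    std::string full = render(history);
    // The slice [prev_len, full.size()) is the only text not yet fed to the model.
    std::string delta(full.begin() + prev_len, full.end());
    printf("turn 1 delta: %s", delta.c_str());
    prev_len = (int) full.size();

    history.push_back("assistant: hello");
    full = render(history);
    delta.assign(full.begin() + prev_len, full.end());
    printf("turn 2 delta: %s", delta.c_str());
    return 0;
}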
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }
 
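Taken together, the hunks set up a small return-code chain: `process_user_message()` yields 0 (keep looping), 1 (hard error), or 2 (one-shot `opt.user` prompt answered); `chat_loop()` maps 1 to failure and 2 to a clean break; `main()` turns any nonzero result into exit code 1. A condensed, hypothetical sketch of that chain (placeholder names, not from the patch):

// 0 = continue, 1 = error, 2 = done (one-shot). Placeholder turn function.
static int fake_turn() {
    static int calls = 0;
    return ++calls < 3 ? 0 : 2;    // two normal turns, then a clean one-shot exit
}

static int loop_sketch() {
    while (true) {
        const int ret = fake_turn();
        if (ret == 1) {
            return 1;              // propagate the error to the caller
        } else if (ret == 2) {
            break;                 // finished cleanly
        }
    }
    return 0;
}

int main() {
    if (loop_sketch()) {
        return 1;                  // mirrors main()'s handling of chat_loop()
    }
    return 0;
}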