@@ -113,6 +113,7 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params model_params;
     std::string model_;
+    std::string chat_template_file;
     std::string user;
     bool use_jinja = false;
     int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
         return 0;
     }

+    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = argv[++i];
+
+        return 0;
+    }
+
     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
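As a sanity check on the new helper, here is a minimal, self-contained sketch (not part of the diff) of how `handle_option_with_value` consumes the token that follows a flag and advances the parser index. The `main` harness and its argv values are illustrative only:

```cpp
#include <cstdio>
#include <cstring>
#include <string>

// Same logic as the helper added above: reject a flag with no value,
// otherwise consume the next token and advance i past it.
static int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
    if (i + 1 >= argc) {
        return 1;  // the flag is the last token; no value follows
    }
    option_value = argv[++i];
    return 0;
}

int main() {
    // Illustrative argv; a real run would receive this from the shell.
    const char * argv[] = { "llama-run", "--chat-template-file", "template.jinja" };
    const int    argc   = 3;

    std::string chat_template_file;
    for (int i = 1; i < argc; ++i) {
        if (strcmp(argv[i], "--chat-template-file") == 0) {
            if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
                return 1;  // missing value after the flag
            }
        }
    }
    printf("chat_template_file = '%s'\n", chat_template_file.c_str());  // prints template.jinja
    return 0;
}
```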
@@ -169,6 +180,11 @@ class Opt {
                 verbose = true;
             } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                 use_jinja = true;
+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+                    return 1;
+                }
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
@@ -207,6 +223,11 @@ class Opt {
             "Options:\n"
             "  -c, --context-size <value>\n"
             "      Context size (default: %d)\n"
+            "  --chat-template-file <path>\n"
+            "      Path to the file containing the chat template to use with the model.\n"
+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
+            "  --jinja\n"
+            "      Use jinja templating for the chat template of the model\n"
             "  -n, -ngl, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
             "  --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
 #endif
 }

-#ifdef LLAMA_USE_CURL
 class File {
   public:
     FILE * file = nullptr;

     FILE * open(const std::string & filename, const char * mode) {
-        file = fopen(filename.c_str(), mode);
+        file = ggml_fopen(filename.c_str(), mode);

         return file;
     }
@@ -303,6 +323,28 @@ class File {
         return 0;
     }

+    std::string read_all(const std::string & filename) {
+        open(filename, "r");
+        if (!file) {
+            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+        lock();
+
+        fseek(file, 0, SEEK_END);
+        size_t size = ftell(file);
+        fseek(file, 0, SEEK_SET);
+
+        std::string out;
+        out.resize(size);
+        size_t read_size = fread(&out[0], 1, size, file);
+        if (read_size != size) {
+            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+        return out;
+    }
+
     ~File() {
         if (fd >= 0) {
#    ifdef _WIN32
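Note that `read_all` reports failure by returning an empty string, both when the file cannot be opened and when the read comes up short, so callers cannot distinguish a genuinely empty file from an error. This is why the chat-template loader added further down treats an empty result as a failure.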
@@ -327,6 +369,7 @@ class File {
 #    endif
 };

+#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
     return 0;
 }

+// Reads the content of a chat template file into a string
+static std::string read_chat_template_file(const std::string & chat_template_file) {
+    if (chat_template_file.empty()) {
+        return "";
+    }
+
+    File file;
+    std::string chat_template = "";
+    chat_template = file.read_all(chat_template_file);
+    if (chat_template.empty()) {
+        printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
+        return "";
+    }
+    return chat_template;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
+
+    std::string chat_template = "";
+    if (!chat_template_file.empty()) {
+        chat_template = read_chat_template_file(chat_template_file);
+    }
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
+
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
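To make the feature concrete, the file passed via `--chat-template-file` would contain a Jinja chat template of the kind models ship in their metadata. The ChatML-style template below is purely illustrative (not from this PR); it uses the `messages` and `add_generation_prompt` variables that Jinja chat templates conventionally receive:

```
{%- for message in messages -%}
<|im_start|>{{ message['role'] }}
{{ message['content'] }}<|im_end|>
{% endfor -%}
{%- if add_generation_prompt -%}
<|im_start|>assistant
{%- endif -%}
```

If the file is missing or empty, `read_chat_template_file` returns an empty string and `common_chat_templates_init` receives an empty override, which preserves the previous behavior of using the model's built-in template.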
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
         return 1;
     }
