
Commit 6b21c3b

engelmi authored and mglambda committed
run : add --chat-template-file (ggml-org#11961)

Relates to: ggml-org#11178. This adds a --chat-template-file CLI option to llama-run. If specified, the file is read and its content is passed to common_chat_templates_from_model, overriding the model's built-in chat template. Signed-off-by: Michael Engel <[email protected]>
1 parent 5b389e6 commit 6b21c3b
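
At its core, the change is a common two-step pattern: slurp the template file into a std::string, then pass that string as an override when the chat templates are initialized. The sketch below shows the pattern standalone, outside llama.cpp; slurp() is an illustrative stand-in for the File::read_all() helper the commit adds, and the file name in main() is hypothetical.

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <string>

// Illustrative stand-in for the commit's File::read_all(): read a whole
// file into a std::string, returning "" on any error.
static std::string slurp(const std::string & path) {
    std::FILE * f = std::fopen(path.c_str(), "rb");
    if (!f) {
        std::fprintf(stderr, "error opening '%s': %s\n", path.c_str(), std::strerror(errno));
        return "";
    }
    std::fseek(f, 0, SEEK_END);
    long size = std::ftell(f);
    std::fseek(f, 0, SEEK_SET);
    std::string out(size > 0 ? (size_t) size : 0, '\0');
    if (std::fread(&out[0], 1, out.size(), f) != out.size()) {
        out.clear(); // short read: treat as failure
    }
    std::fclose(f);
    return out;
}

int main() {
    // Hypothetical template file; with the commit applied, llama-run does
    // the analogous read when --chat-template-file is given.
    std::string tmpl = slurp("chat_template.jinja");
    std::printf("read %zu bytes\n", tmpl.size());
    return 0;
}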

File tree: 1 file changed, +70 −5 lines changed

examples/run/run.cpp

Lines changed: 70 additions & 5 deletions
@@ -113,6 +113,7 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params model_params;
     std::string model_;
+    std::string chat_template_file;
     std::string user;
     bool use_jinja = false;
     int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
         return 0;
     }
 
+    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = argv[++i];
+
+        return 0;
+    }
+
     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
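
The handle_option_with_value() helper added above encapsulates the "flag expects an argument" pattern: it advances the argv index past the value and reports failure when the flag is the last token. Here is a minimal self-contained check of that contract (a sketch; the helper's logic is copied out of the Opt class for illustration):

#include <cassert>
#include <string>

// Copy of the helper's logic, pulled out of the Opt class for illustration.
static int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
    if (i + 1 >= argc) {
        return 1; // flag given without a value
    }
    option_value = argv[++i];
    return 0;
}

int main() {
    const char * ok[]  = { "llama-run", "--chat-template-file", "tmpl.jinja" };
    const char * bad[] = { "llama-run", "--chat-template-file" };
    std::string value;

    int i = 1;
    assert(handle_option_with_value(3, ok, i, value) == 0 && value == "tmpl.jinja" && i == 2);

    i = 1; // value missing: the helper reports failure and parse() returns 1
    assert(handle_option_with_value(2, bad, i, value) == 1);
    return 0;
}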
@@ -169,6 +180,11 @@ class Opt {
                 verbose = true;
             } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                 use_jinja = true;
+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+                    return 1;
+                }
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
@@ -207,6 +223,11 @@ class Opt {
             "Options:\n"
             "  -c, --context-size <value>\n"
             "      Context size (default: %d)\n"
+            "  --chat-template-file <path>\n"
+            "      Path to the file containing the chat template to use with the model.\n"
+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
+            "  --jinja\n"
+            "      Use jinja templating for the chat template of the model\n"
             "  -n, -ngl, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
             "  --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
 #endif
 }
 
-#ifdef LLAMA_USE_CURL
 class File {
   public:
     FILE * file = nullptr;
 
     FILE * open(const std::string & filename, const char * mode) {
-        file = fopen(filename.c_str(), mode);
+        file = ggml_fopen(filename.c_str(), mode);
 
         return file;
     }
@@ -303,6 +323,28 @@ class File {
         return 0;
     }
 
+    std::string read_all(const std::string & filename){
+        open(filename, "r");
+        lock();
+        if (!file) {
+            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+
+        fseek(file, 0, SEEK_END);
+        size_t size = ftell(file);
+        fseek(file, 0, SEEK_SET);
+
+        std::string out;
+        out.resize(size);
+        size_t read_size = fread(&out[0], 1, size, file);
+        if (read_size != size) {
+            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+        return out;
+    }
+
     ~File() {
         if (fd >= 0) {
 # ifdef _WIN32
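
For readers who do not need the File wrapper's locking and RAII machinery, the same whole-file read can be written with the standard <fstream> facilities; a rough equivalent (not the commit's code, and with a hypothetical path in main()) looks like this:

#include <fstream>
#include <sstream>
#include <string>

// Rough standard-library equivalent of File::read_all() (sketch only):
// returns "" if the file cannot be opened, mirroring the error convention above.
static std::string read_all_ifstream(const std::string & path) {
    std::ifstream in(path, std::ios::binary);
    if (!in) {
        return "";
    }
    std::ostringstream ss;
    ss << in.rdbuf(); // stream the entire file into the buffer
    return ss.str();
}

int main() {
    return read_all_ifstream("chat_template.jinja").empty() ? 1 : 0;
}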
@@ -327,6 +369,7 @@ class File {
 # endif
 };
 
+#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
     return 0;
 }
 
+// Reads a chat template file to be used
+static std::string read_chat_template_file(const std::string & chat_template_file) {
+    if(chat_template_file.empty()){
+        return "";
+    }
+
+    File file;
+    std::string chat_template = "";
+    chat_template = file.read_all(chat_template_file);
+    if(chat_template.empty()){
+        printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
+        return "";
+    }
+    return chat_template;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
+
+    std::string chat_template = "";
+    if(!chat_template_file.empty()){
+        chat_template = read_chat_template_file(chat_template_file);
+    }
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
+
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
         return 1;
     }
 
