Added --chat-template-file to llama-run #11961

Merged

75 changes: 70 additions & 5 deletions examples/run/run.cpp
@@ -113,6 +113,7 @@ class Opt {
llama_context_params ctx_params;
llama_model_params model_params;
std::string model_;
std::string chat_template_file;
std::string user;
bool use_jinja = false;
int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
return 0;
}

int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
if (i + 1 >= argc) {
return 1;
}

option_value = argv[++i];

return 0;
}

int parse(int argc, const char ** argv) {
bool options_parsing = true;
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
return 1;
}
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true;
return 0;
@@ -207,6 +223,11 @@ class Opt {
"Options:\n"
" -c, --context-size <value>\n"
" Context size (default: %d)\n"
" --chat-template-file <path>\n"
" Path to the file containing the chat template to use with the model.\n"
" Only supports jinja templates and implicitly sets the --jinja flag.\n"
" --jinja\n"
" Use jinja templating for the chat template of the model\n"
" -n, -ngl, --ngl <value>\n"
" Number of GPU layers (default: %d)\n"
" --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
#endif
}

#ifdef LLAMA_USE_CURL
class File {
public:
FILE * file = nullptr;

FILE * open(const std::string & filename, const char * mode) {
file = fopen(filename.c_str(), mode);
file = ggml_fopen(filename.c_str(), mode);

return file;
}
@@ -303,6 +323,28 @@ class File {
return 0;
}

std::string read_all(const std::string & filename){
open(filename, "r");
lock();
if (!file) {
printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
return "";
}

fseek(file, 0, SEEK_END);
size_t size = ftell(file);
fseek(file, 0, SEEK_SET);

std::string out;
out.resize(size);
size_t read_size = fread(&out[0], 1, size, file);
if (read_size != size) {
printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
return "";
}
return out;
}

~File() {
if (fd >= 0) {
# ifdef _WIN32
@@ -327,6 +369,7 @@
# endif
};

#ifdef LLAMA_USE_CURL
class HttpClient {
public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
return 0;
}

// Reads a chat template file to be used
static std::string read_chat_template_file(const std::string & chat_template_file) {
if(chat_template_file.empty()){
return "";
}

File file;
std::string chat_template = "";
chat_template = file.read_all(chat_template_file);
if(chat_template.empty()){
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
return "";
}
return chat_template;
}

// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");

std::string chat_template = "";
if(!chat_template_file.empty()){
chat_template = read_chat_template_file(chat_template_file);
}
auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
@ericcurtin (Collaborator) commented on Feb 20, 2025:

We should do:

chat_template.empty() ? "" : chat_template

here. Passing nullptr to a reference is not allowed. I wish the compiler caught these things.

common_chat_templates_ptr common_chat_templates_init(
                                    const struct llama_model * model,
                                           const std::string & chat_template_override,
                                           const std::string & bos_token_override = "",
                                           const std::string & eos_token_override = "")

@ericcurtin (Collaborator) commented on Feb 20, 2025:

Maybe the std::string class is smart enough to interpret all these as the same thing:

"", '', 0, NULL, nullptr

and that's why it compiles/works 🤷 . So it might be just implicitly converting it to "".
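
A minimal standalone sketch of the point above (use_override is an illustrative stand-in, not code from this PR): std::string is not "smart" here — nullptr is implicitly converted through the std::string(const char *) constructor, and constructing a std::string from a null pointer is undefined behaviour, so the nullptr form only appears to work:

#include <iostream>
#include <string>

// Stand-in for common_chat_templates_init's const-reference parameter.
static void use_override(const std::string & chat_template_override) {
    std::cout << "override length: " << chat_template_override.size() << '\n';
}

int main() {
    std::string chat_template; // empty, as when no template file is given

    // Safe: the literal "" is converted to an empty std::string temporary.
    use_override(chat_template.empty() ? "" : chat_template);

    // Unsafe: nullptr would be converted via std::string(const char *),
    // which is undefined behaviour for a null pointer (libstdc++ happens
    // to throw std::logic_error, which is why it can appear to "work").
    // use_override(chat_template.empty() ? nullptr : chat_template);

    return 0;
}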


static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
return 1;
}

if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
return 1;
}

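For reference, a hypothetical invocation of the new flag (the model name and template path are illustrative, not from this PR); as documented in the help text above, --chat-template-file implicitly enables --jinja:

llama-run --chat-template-file ./chat-template.jinja granite-code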