Added --chat-template-file to llama-run #11922
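
This PR adds a --chat-template-file option to llama-run. The option reads a Jinja chat template from the given file, passes it to common_chat_templates_init() in place of the template embedded in the model's metadata, and implicitly enables --jinja. To support this, the File helper gains a read_all() method and is moved out of the LLAMA_USE_CURL guard.

As a usage sketch (the model name and template path are illustrative, not taken from this PR):

    llama-run --chat-template-file ./chat_template.jinja granite-code "Hello"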

75 changes: 70 additions & 5 deletions in examples/run/run.cpp
@@ -113,6 +113,7 @@ class Opt {
    llama_context_params ctx_params;
    llama_model_params model_params;
    std::string model_;
    std::string chat_template_file;
    std::string user;
    bool use_jinja = false;
    int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
        return 0;
    }

    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
        if (i + 1 >= argc) {
            return 1;
        }

        option_value = argv[++i];

        return 0;
    }

    int parse(int argc, const char ** argv) {
        bool options_parsing = true;
        for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@
                verbose = true;
            } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                use_jinja = true;
            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
                    return 1;
                }
                use_jinja = true;
            } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                help = true;
                return 0;
@@ -207,6 +223,11 @@ class Opt {
            "Options:\n"
            "  -c, --context-size <value>\n"
            "      Context size (default: %d)\n"
            "  --chat-template-file <path>\n"
            "      Path to the file containing the chat template to use with the model.\n"
            "      Only supports Jinja templates and implicitly sets the --jinja flag.\n"
            "  --jinja\n"
            "      Use Jinja templating for the chat template of the model\n"
            "  -n, -ngl, --ngl <value>\n"
            "      Number of GPU layers (default: %d)\n"
            "  --temp <value>\n"
@@ -261,13 +282,16 @@ static int get_terminal_width() {
#endif
}

class File {
  public:
    FILE * file = nullptr;

    FILE * open(const std::string & filename, const char * mode) {
        // ggml_fopen handles UTF-8 paths on Windows; assign the member so the destructor closes it
        file = ggml_fopen(filename.c_str(), mode);
        if (!file) {
            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
            return NULL;
        }

        return file;
    }
@@ -303,6 +327,25 @@ class File {
        return 0;
    }

    // Returns the whole content of the file as a string, or "" on failure
    std::string read_all(const std::string & filename) {
        file = open(filename, "r");
        if (!file) {
            return "";
        }

        fseek(file, 0, SEEK_END);
        const size_t size = ftell(file);
        fseek(file, 0, SEEK_SET);

        std::string out;
        out.resize(size);
        const size_t read_size = fread(&out[0], 1, size, file);
        if (read_size != size) {
            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
            return "";
        }

        return out;
    }

    ~File() {
        if (fd >= 0) {
# ifdef _WIN32
@@ -327,6 +370,7 @@ class File {
# endif
};

#ifdef LLAMA_USE_CURL
class HttpClient {
  public:
    int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1097,32 @@ static int get_user_input(std::string & user_input, const std::string & user) {
    return 0;
}

// Reads the chat template file into a string; returns "" if no file was given or reading failed
static std::string read_chat_template_file(const std::string & chat_template_file) {
    if (chat_template_file.empty()) {
        return "";
    }

    File file;
    std::string chat_template = file.read_all(chat_template_file);
    if (chat_template.empty()) {
        printe("Error reading chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
        return "";
    }

    return chat_template;
}
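
For reference, a file passed via --chat-template-file would hold a Jinja chat template along these lines (a minimal sketch using the customary messages/add_generation_prompt variables, not a template shipped with this PR):

    {%- for message in messages %}
    <|{{ message.role }}|>{{ message.content }}
    {%- endfor %}
    {%- if add_generation_prompt %}<|assistant|>{%- endif %}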

// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
    int prev_len = 0;
    llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));

    // An explicit template file overrides the template embedded in the model
    std::string chat_template;
    if (!chat_template_file.empty()) {
        chat_template = read_chat_template_file(chat_template_file);
    }
    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);

    static const bool stdout_a_terminal = is_stdout_a_terminal();
    while (true) {
        // Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
        return 1;
    }

    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
        return 1;
    }
