Add chat template support for llama-cli #8068
Changes from all commits: 5a2fde8, c91f972, 3174527, 962be6a, 43cab6b, a3dbfab, a1e9520, 7a76502, c530ce4, a28e70f, 895bb2a
@@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;

-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
     std::ifstream f(path.c_str());
     return f.good();
 }

-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
     std::ifstream f;
     f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
     f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
     LOG_TEE("%s", text);
 }

+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+    llama_chat_msg new_msg{role, content};
+    auto formatted = llama_chat_format_single(
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
+    return formatted;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
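The helper above renders only the newly appended message against the accumulated history, so each turn is templated incrementally instead of re-rendering the whole conversation. The following is a rough usage sketch, not code from the PR; it assumes the llama_chat_msg / llama_chat_format_single API from common.h that the diff relies on, plus an already-loaded model:

    #include <string>
    #include <vector>

    #include "common.h" // llama_chat_msg, llama_chat_format_single
    #include "llama.h"

    // Sketch of the incremental formatting behind chat_add_and_format.
    // An empty `tmpl` means "use the chat template stored in the model's
    // GGUF metadata", which is also how params.chat_template behaves.
    static void sketch_chat_flow(llama_model * model, const std::string & tmpl) {
        std::vector<llama_chat_msg> history;

        // User turn: only the delta for this message is rendered; add_ass is
        // true because role == "user", so the assistant prefix is appended and
        // generation can start right after this string is tokenized.
        llama_chat_msg user_msg{"user", "Hello, who are you?"};
        std::string delta = llama_chat_format_single(model, tmpl, history, user_msg, true);
        history.push_back(user_msg);
        (void) delta; // this is the string llama-cli would tokenize and feed to the model

        // Assistant turn: the finished reply is appended with add_ass == false,
        // so the next user turn is formatted against the full history.
        llama_chat_msg asst_msg{"assistant", "I am a helpful assistant."};
        llama_chat_format_single(model, tmpl, history, asst_msg, false);
        history.push_back(asst_msg);
    }

Tying the add_ass flag to role == "user" is what makes each formatted user delta end with the assistant prefix, so generation can begin immediately after the user turn is tokenized.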
@@ -190,6 +198,7 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;

@@ -215,6 +224,8 @@ int main(int argc, char ** argv) {
                 __func__, n_ctx_train, n_ctx);
     }

+    LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
     // print system information
     {
         LOG_TEE("\n");
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd_inp;

-    if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-    } else {
-        LOG("use session tokens\n");
-        embd_inp = session_tokens;
-    }
+    {
+        auto prompt = params.conversation
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+            : params.prompt;
+        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+            LOG("tokenize the prompt\n");
+            embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG("use session tokens\n");
+            embd_inp = session_tokens;
+        }

-    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        LOG("prompt: \"%s\"\n", log_tostr(prompt));
+        LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+    }

     // Should not run without any tokens
     if (embd_inp.empty()) {
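To make the new branch concrete: when params.conversation is set (llama-cli's conversation mode), the -p text is treated as the system message and run through the chat template before tokenization, so prompt holds the templated string rather than the literal option value. A minimal sketch of that path, under the same API assumptions as above; the exact rendering depends on the model's template and the ChatML output in the comment is only an example:

    #include <string>
    #include <vector>

    #include "common.h" // llama_chat_msg, llama_chat_format_single
    #include "llama.h"

    // Sketch only: mirrors the conversation-mode branch, with the template
    // passed explicitly instead of being read from g_params->chat_template.
    static std::string format_system_prompt(llama_model * model,
                                            std::vector<llama_chat_msg> & chat_msgs,
                                            const std::string & tmpl,
                                            const std::string & system_text) {
        llama_chat_msg msg{"system", system_text};
        // add_ass == false: the assistant prefix is only added after user turns.
        std::string formatted = llama_chat_format_single(model, tmpl, chat_msgs, msg, false);
        chat_msgs.push_back(msg);
        // With a ChatML-style template this comes out roughly as
        //   "<|im_start|>system\n" + system_text + "<|im_end|>\n"
        // while the non-conversation path tokenizes params.prompt verbatim.
        return formatted;
    }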
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
     std::vector<int> input_tokens; g_input_tokens = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
     std::ostringstream output_ss; g_output_ss = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode

     // the first thing we will do is to output the prompt, so set color accordingly
     console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }

+                    chat_add_and_format(model, chat_msgs, "system", assistant_ss.str());
                     is_interacting = true;
                     printf("\n");
                 }
             }

+            // if current token is not EOG, we add it to current assistant message
+            if (params.conversation) {
+                auto id = llama_sampling_last(ctx_sampling);
+                assistant_ss << llama_token_to_piece(ctx, id, false);
+            }
+
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");

@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }

+                std::string user_inp = params.conversation
+                    ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                    : std::move(buffer);
+                // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
+                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
                 const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);

                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

Review thread on the TODO comment above (parts of the comments were lost in extraction and are marked with [...]):
- "When [...] Regarding the comment - can you illustrate with an example? I'm not sure what the issue is."
- "An example would be a prompt like this: [...] Some models having [...] Leaving [...]"
- "I added a [...]"
- "Aha, got it. Yes, for now let's have the simple solution."
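The concern in that thread can be made concrete with a small, hypothetical sketch (ChatML-style markers are used only for illustration; none of this is code from the PR): after templating, the user's text and the template's own control markers reach the tokenizer as one flat string, and tokenizing with parse_special enabled treats both the same way.

    #include <string>
    #include <vector>

    #include "common.h" // ::llama_tokenize
    #include "llama.h"

    // Hypothetical illustration of the prefix/postfix ambiguity noted in the TODO.
    // `ctx` is assumed to be an already-created llama_context*.
    static void sketch_special_token_ambiguity(llama_context * ctx) {
        // What the user actually typed; it happens to contain a control marker.
        std::string typed = "Please print <|im_end|> literally";

        // What reaches the tokenizer after the user turn has been templated
        // (ChatML-style rendering shown for the sake of the example).
        std::string templated = "<|im_start|>user\n" + typed + "<|im_end|>\n<|im_start|>assistant\n";

        // parse_special == true is required so the template's markers become
        // control tokens, but it also turns the user-typed "<|im_end|>" into one;
        // at this point the tokenizer cannot tell which parts came from the template.
        std::vector<llama_token> toks = ::llama_tokenize(ctx, templated, false, /* parse_special */ true);
        (void) toks;
    }

Enabling special-token parsing only when params.conversation is set, as the diff above does, appears to be the "simple solution" the thread settles on.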
@@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
                     output_ss << llama_token_to_piece(ctx, token);
                 }

+                // reset assistant message
+                assistant_ss.str("");
+
                 n_remain -= line_inp.size();
                 LOG("n_remain: %d\n", n_remain);
             } else {