Skip to content

sync: minja #11352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions common/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,22 @@ class chat_template {
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool requires_typed_content_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;

std::string try_render(
std::string try_raw_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
Expand All @@ -60,7 +61,7 @@ class chat_template {
supports_tools_ = source.find("tools") != std::string::npos;

auto renders_string_arguments =
try_render({
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
Expand All @@ -81,7 +82,7 @@ class chat_template {
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_render({
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
Expand All @@ -106,10 +107,13 @@ class chat_template {
}
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;

supports_system_role_ = try_render({
supports_system_role_ = try_raw_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;

requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
}

const std::string & source() const { return source_; }
Expand All @@ -122,19 +126,34 @@ class chat_template {
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
bool adjust_inputs = true) const
{
json actual_messages;

// First, "fix" messages so they have a chance to be rendered correctly by the template

if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
actual_messages = json::array();

// Appends `msg` to `actual_messages`.
// When `requires_typed_content_` is set and `content` is a plain string, the
// content is wrapped as a one-element array of typed parts:
//   [{"type": "text", "text": <content>}]
// The message is copied before its content is replaced, so every other key
// present on the message (not only "role") is kept in the adjusted message.
auto add_message = [&](const json & msg) {
    // is_string() is already false for null, so no separate null check is needed.
    const bool wrap_content = requires_typed_content_
        && msg.contains("content")
        && msg.at("content").is_string();
    if (!wrap_content) {
        actual_messages.push_back(msg);
        return;
    }
    json typed_msg = msg;
    typed_msg["content"] = json::array({
        {
            {"type", "text"},
            {"text", msg.at("content")},
        },
    });
    actual_messages.push_back(typed_msg);
};

std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
add_message({
{"role", "user"},
{"content", pending_system},
});
Expand Down Expand Up @@ -217,7 +236,7 @@ class chat_template {
}
}
}
actual_messages.push_back(message);
add_message(message);
}
flush_sys();
} else {
Expand Down
28 changes: 26 additions & 2 deletions common/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };

class TemplateToken {
public:
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };

static std::string typeToString(Type t) {
switch (t) {
Expand All @@ -712,6 +712,8 @@ class TemplateToken {
case Type::EndMacro: return "endmacro";
case Type::Filter: return "filter";
case Type::EndFilter: return "endfilter";
case Type::Generation: return "generation";
case Type::EndGeneration: return "endgeneration";
}
return "Unknown";
}
Expand Down Expand Up @@ -788,6 +790,14 @@ struct EndForTemplateToken : public TemplateToken {
EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {}
};

// Token for the opening tag of a `{% generation %}` block.
// Adds no fields of its own: it only tags the base TemplateToken with
// Type::Generation and forwards the location and whitespace handling.
struct GenerationTemplateToken : public TemplateToken {
    GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
        : TemplateToken(Type::Generation, location, pre, post)
    {
    }
};

// Token for the closing `{% endgeneration %}` tag.
// Adds no fields of its own: it only tags the base TemplateToken with
// Type::EndGeneration and forwards the location and whitespace handling.
struct EndGenerationTemplateToken : public TemplateToken {
    EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post)
        : TemplateToken(Type::EndGeneration, location, pre, post)
    {
    }
};

struct SetTemplateToken : public TemplateToken {
std::string ns;
std::vector<std::string> var_names;
Expand Down Expand Up @@ -2149,7 +2159,7 @@ class Parser {
static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
Expand Down Expand Up @@ -2229,6 +2239,12 @@ class Parser {
} else if (keyword == "endfor") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
} else if (keyword == "generation") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<GenerationTemplateToken>(location, pre_space, post_space));
} else if (keyword == "endgeneration") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
} else if (keyword == "set") {
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");

Expand Down Expand Up @@ -2330,6 +2346,13 @@ class Parser {
throw unterminated(**start);
}
children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
} else if (dynamic_cast<GenerationTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) {
throw unterminated(**start);
}
// Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking).
children.emplace_back(std::move(body));
} else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
Expand Down Expand Up @@ -2397,6 +2420,7 @@ class Parser {
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
|| dynamic_cast<EndIfTemplateToken*>(token.get())
|| dynamic_cast<ElseTemplateToken*>(token.get())
|| dynamic_cast<EndGenerationTemplateToken*>(token.get())
|| dynamic_cast<ElifTemplateToken*>(token.get())) {
it--; // unconsume the token
break; // exit the loop
Expand Down
Loading