
Commit 9a93863

ochafik authored and matteoserva committed
chat.cpp: simplify calls to apply to ensure systematic propagation of extra_context (+ the odd existing additional_context)
1 parent 67789ef · commit 9a93863

1 file changed

common/chat.cpp

Lines changed: 28 additions & 21 deletions
@@ -17,6 +17,8 @@
 #include <string>
 #include <vector>
 
+using json = nlohmann::ordered_json;
+
 static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
     auto time = std::chrono::system_clock::to_time_t(now);
     auto local_time = *std::localtime(&time);
@@ -721,16 +723,23 @@ static void foreach_function(const json & tools, const std::function<void(const
 
 static std::string apply(
     const common_chat_template & tmpl,
-    const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools,
-    bool add_generation_prompt,
-    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+    const struct templates_params & inputs,
+    const std::optional<json> & messages_override = std::nullopt,
+    const std::optional<json> & tools_override = std::nullopt,
+    const std::optional<json> & additional_context = std::nullopt)
 {
     minja::chat_template_inputs tmpl_inputs;
-    tmpl_inputs.messages = messages;
-    tmpl_inputs.tools = tools;
-    tmpl_inputs.add_generation_prompt = add_generation_prompt;
-    tmpl_inputs.extra_context = extra_context;
+    tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
+    if (tools_override) {
+        tmpl_inputs.tools = *tools_override;
+    } else {
+        tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
+    }
+    tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
+    tmpl_inputs.extra_context = inputs.extra_context;
+    if (additional_context) {
+        tmpl_inputs.extra_context.merge_patch(*additional_context);
+    }
     // TODO: add flag to control date/time, if only for testing purposes.
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
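Note: the new additional_context handling relies on nlohmann::json's merge_patch() (JSON Merge Patch, RFC 7386): keys in the patch overwrite keys in the target, and a null patch value deletes the target key. A minimal standalone sketch of those semantics, with invented values (not part of the commit):

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // Caller-supplied context, as inputs.extra_context arrives in apply().
    json extra_context = {{"custom_key", "kept"}, {"date_string", "caller value"}};
    // Per-format keys, as a call site would pass via additional_context.
    json additional_context = {{"date_string", "26 Jul 2025"}, {"obsolete_key", nullptr}};
    extra_context.merge_patch(additional_context);
    // Prints {"custom_key":"kept","date_string":"26 Jul 2025"}: the caller's
    // keys survive, the patch wins on conflicts, and the null value removes
    // "obsolete_key" (a no-op here, since it was never set).
    std::cout << extra_context.dump() << std::endl;
}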
@@ -829,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
-    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -905,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.preserved_tokens = {
         "[TOOL_CALLS]",
     };
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -935,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             adjusted_messages.push_back(msg);
         }
     }
-    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
         if (!inputs.enable_thinking) {
@@ -1123,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json {
        {"date_string", format_time(inputs.now, "%d %b %Y")},
        {"tools_in_user_message", false},
        {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
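Note: before this change, the llama 3.x path above passed only these per-format keys as extra_context, so any caller-supplied inputs.extra_context was dropped on this code path; with the override scheme the caller's context is propagated systematically and the per-format keys are merge-patched on top (see the sketch above). One subtlety of that merge: when builtin_tools is empty, the json() value is a null patch entry, which under RFC 7386 removes the key rather than storing a null.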
@@ -1188,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs);
 
     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1283,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json {
        {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
        {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
    });
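The firefunction_v2 call above also shows why tools_override is std::optional<json> rather than plain json: std::nullopt means "fall back to inputs.tools", while an engaged json() forces a null tools value so the template renders tool-free. A standalone sketch of that tri-state dispatch (hypothetical helper name, restating the tools_override branch of apply()):

#include <optional>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Hypothetical helper mirroring apply()'s tools resolution.
static json resolve_tools(const std::optional<json> & tools_override, const json & tools) {
    if (tools_override) {
        return *tools_override;            // engaged: use it, even if it is json() (null)
    }
    return tools.empty() ? json() : tools; // disengaged: fall back to inputs.tools
}

int main() {
    json tools = json::array({ {{"type", "function"}} });
    json a = resolve_tools(std::nullopt, tools); // -> tools (normalized to null when empty)
    json b = resolve_tools(json(), tools);       // -> null: template sees no tools at all
    return a.is_array() && b.is_null() ? 0 : 1;
}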
@@ -1339,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1466,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
 
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     // TODO: if (has_raw_python)
     return data;
 }
@@ -1499,11 +1508,9 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
-    json additional_context = {
+    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json {
        {"enable_thinking", inputs.enable_thinking},
-    };
-
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
+    });
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     if (string_ends_with(data.prompt, "<think>\n")) {
         if (!inputs.enable_thinking) {
@@ -1692,7 +1699,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
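Net effect at call sites: a format handler that needs no overrides now reduces to a single apply(tmpl, inputs) call, with messages, tools, add_generation_prompt, and extra_context all drawn from inputs. A hypothetical new handler (illustrative only, not part of the commit; relies on the declarations already in common/chat.cpp) would look like:

static common_chat_params common_chat_params_init_example(const common_chat_template & tmpl, const struct templates_params & inputs) {
    common_chat_params data;
    data.prompt = apply(tmpl, inputs);   // all common fields flow from inputs
    data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    return data;
}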
