|
17 | 17 | #include <string>
|
18 | 18 | #include <vector>
|
19 | 19 |
|
| 20 | +using json = nlohmann::ordered_json; |
| 21 | + |
20 | 22 | static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
|
21 | 23 | auto time = std::chrono::system_clock::to_time_t(now);
|
22 | 24 | auto local_time = *std::localtime(&time);
|
@@ -721,16 +723,23 @@ static void foreach_function(const json & tools, const std::function<void(const
|
721 | 723 |
|
722 | 724 | static std::string apply(
|
723 | 725 | const common_chat_template & tmpl,
|
724 |
| - const nlohmann::ordered_json & messages, |
725 |
| - const nlohmann::ordered_json & tools, |
726 |
| - bool add_generation_prompt, |
727 |
| - const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) |
| 726 | + const struct templates_params & inputs, |
| 727 | + const std::optional<json> & messages_override = std::nullopt, |
| 728 | + const std::optional<json> & tools_override = std::nullopt, |
| 729 | + const std::optional<json> & additional_context = std::nullopt) |
728 | 730 | {
|
729 | 731 | minja::chat_template_inputs tmpl_inputs;
|
730 |
| - tmpl_inputs.messages = messages; |
731 |
| - tmpl_inputs.tools = tools; |
732 |
| - tmpl_inputs.add_generation_prompt = add_generation_prompt; |
733 |
| - tmpl_inputs.extra_context = extra_context; |
| 732 | + tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; |
| 733 | + if (tools_override) { |
| 734 | + tmpl_inputs.tools = *tools_override; |
| 735 | + } else { |
| 736 | + tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools; |
| 737 | + } |
| 738 | + tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; |
| 739 | + tmpl_inputs.extra_context = inputs.extra_context; |
| 740 | + if (additional_context) { |
| 741 | + tmpl_inputs.extra_context.merge_patch(*additional_context); |
| 742 | + } |
734 | 743 | // TODO: add flag to control date/time, if only for testing purposes.
|
735 | 744 | // tmpl_inputs.now = std::chrono::system_clock::now();
|
736 | 745 |
|
@@ -829,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
829 | 838 | inputs.messages,
|
830 | 839 | "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
|
831 | 840 |
|
832 |
| - data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context); |
| 841 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); |
833 | 842 | data.format = COMMON_CHAT_FORMAT_GENERIC;
|
834 | 843 | return data;
|
835 | 844 | }
|
@@ -905,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
905 | 914 | data.preserved_tokens = {
|
906 | 915 | "[TOOL_CALLS]",
|
907 | 916 | };
|
908 |
| - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 917 | + data.prompt = apply(tmpl, inputs); |
909 | 918 | data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
910 | 919 | return data;
|
911 | 920 | }
|
@@ -935,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
935 | 944 | adjusted_messages.push_back(msg);
|
936 | 945 | }
|
937 | 946 | }
|
938 |
| - data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); |
| 947 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); |
939 | 948 | data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
|
940 | 949 | if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
|
941 | 950 | if (!inputs.enable_thinking) {
|
@@ -1123,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
|
1123 | 1132 | } else {
|
1124 | 1133 | data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
1125 | 1134 | }
|
1126 |
| - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { |
| 1135 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
1127 | 1136 | {"date_string", format_time(inputs.now, "%d %b %Y")},
|
1128 | 1137 | {"tools_in_user_message", false},
|
1129 | 1138 | {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
@@ -1188,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
|
1188 | 1197 |
|
1189 | 1198 | static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1190 | 1199 | common_chat_params data;
|
1191 |
| - auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 1200 | + auto prompt = apply(tmpl, inputs); |
1192 | 1201 |
|
1193 | 1202 | // Hacks to fix the official (broken) prompt.
|
1194 | 1203 | // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
|
@@ -1283,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
|
1283 | 1292 | static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1284 | 1293 | LOG_DBG("%s\n", __func__);
|
1285 | 1294 | common_chat_params data;
|
1286 |
| - data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { |
| 1295 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ json(), json {
1287 | 1296 | {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
1288 | 1297 | {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
1289 | 1298 | });
|
@@ -1339,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
1339 | 1348 | // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
1340 | 1349 | // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
|
1341 | 1350 | common_chat_params data;
|
1342 |
| - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 1351 | + data.prompt = apply(tmpl, inputs); |
1343 | 1352 | data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
1344 | 1353 | if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
1345 | 1354 | data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
@@ -1466,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
1466 | 1475 | data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
1467 | 1476 | }
|
1468 | 1477 |
|
1469 |
| - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 1478 | + data.prompt = apply(tmpl, inputs); |
1470 | 1479 | // TODO: if (has_raw_python)
|
1471 | 1480 | return data;
|
1472 | 1481 | }
|
@@ -1499,11 +1508,9 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
|
1499 | 1508 | static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1500 | 1509 | common_chat_params data;
|
1501 | 1510 |
|
1502 |
| - json additional_context = { |
| 1511 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
1503 | 1512 | {"enable_thinking", inputs.enable_thinking},
|
1504 |
| - }; |
1505 |
| - |
1506 |
| - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context); |
| 1513 | + }); |
1507 | 1514 | data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
1508 | 1515 | if (string_ends_with(data.prompt, "<think>\n")) {
|
1509 | 1516 | if (!inputs.enable_thinking) {
|
@@ -1692,7 +1699,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
|
1692 | 1699 |
|
1693 | 1700 | static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1694 | 1701 | common_chat_params data;
|
1695 |
| - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context); |
| 1702 | + data.prompt = apply(tmpl, inputs); |
1696 | 1703 | data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
1697 | 1704 | data.grammar_lazy = false;
|
1698 | 1705 | if (!inputs.json_schema.is_null()) {
|
|
0 commit comments