Commit 769cd71

tool-call: Phi-4 support
- Add a system message if needed (per template requirement)
- Add tools to the system message (required by the template)
- Parse output (see the usage sketch below):
  - add tool calls to the response when there is valid JSON between <|tool_call|> and </|tool_call|>
  - content outside of the tool_call tags is added to the text portion of the response
  - if there is no valid JSON, the entire content is added to the text portion of the response
1 parent 1e2f78a commit 769cd71
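For orientation, here is a minimal sketch (not part of the commit) of how a Phi-4 response would be split by the new parser. It assumes the common_chat_parse()/common_chat_msg API from common/chat.h as it appears in this diff (tool-call entries with name/arguments/id fields); the get_weather tool and its arguments are invented for the example.

// Minimal usage sketch, not part of this commit.
// Assumes common_chat_parse(), common_chat_msg and COMMON_CHAT_FORMAT_PHI_4
// from common/chat.h as shown in this diff; "get_weather" is a made-up tool.
#include <cstdio>
#include <string>

#include "chat.h" // common/chat.h

int main() {
    // Valid JSON between <|tool_call|> and </|tool_call|> becomes a tool call;
    // everything outside the tags stays in the text content.
    const std::string model_output =
        "Let me look that up."
        "<|tool_call|>{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}</|tool_call|>";

    common_chat_msg msg = common_chat_parse(model_output, COMMON_CHAT_FORMAT_PHI_4);

    printf("content: %s\n", msg.content.c_str()); // "Let me look that up."
    for (const auto & tc : msg.tool_calls) {
        // expected: name = "get_weather", arguments = {"city":"Paris"}
        printf("tool call: %s(%s)\n", tc.name.c_str(), tc.arguments.c_str());
    }
    return 0;
}

If the JSON between the tags fails to parse, the whole tagged span would instead be appended to msg.content, per the rules above.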

File tree

5 files changed: +220 -1 lines changed

- common/chat.cpp
- common/chat.h
- models/templates/README.md
- models/templates/microsoft-Phi-4-mini-instruct.jinja
- tests/test-chat.cpp

common/chat.cpp

Lines changed: 186 additions & 0 deletions
@@ -444,6 +444,7 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
         case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
+        case COMMON_CHAT_FORMAT_PHI_4: return "Phi-4";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -1344,6 +1345,184 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
+static common_chat_params common_chat_params_init_phi_4(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    // Phi-4 has a unique format that expects tools in the system message with <|tool|> tags
+    // and returns function calls as a JSON object after <|tool_call|> tag
+    common_chat_params data;
+
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+        std::vector<std::string> tool_rules;
+        std::vector<std::string> tool_call_alts;
+        foreach_function(inputs.tools, [&](const json & tool) {
+            const auto & function = tool.at("function");
+            std::string name = function.at("name");
+            auto parameters = function.at("parameters");
+            builder.resolve_refs(parameters);
+            tool_rules.push_back(builder.add_schema(name + "-call", {
+                {"type", "object"},
+                {"properties", {
+                    {"name", {{"const", name}}},
+                    {"arguments", parameters},
+                }},
+                {"required", json::array({"name", "arguments"})},
+            }));
+        });
+        auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
+        std::vector<std::string> alt_tags {
+            any_tool_call,
+        };
+        tool_call_alts.push_back(any_tool_call);
+        auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
+        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call|>"});
+        data.preserved_tokens = {
+            "<|tool_call|>",
+            "</|tool_call|>",
+        };
+    });
+
+    // For Phi-4, we need to inject tools into the system message
+    // because the template expects tools in the system message with <|tool|> tags
+    if (inputs.tools.empty()) {
+        // No tools, use normal approach
+        data.prompt = apply(tmpl, inputs.messages, json::array(), inputs.add_generation_prompt);
+    } else {
+        // Make a copy of messages that we can modify
+        json adjusted_messages = inputs.messages;
+
+        // Extract just the function part of the OpenAI-formatted tools
+        json phi4_tools = json::array();
+        foreach_function(inputs.tools, [&](const json & tool) {
+            phi4_tools.push_back(tool.at("function"));
+        });
+
+        // Phi-4 template expects tools in the system message with <|tool|> tags.
+        // Find the system message, or add one if it doesn't exist
+        bool found_system_msg = false;
+        for (auto & message : adjusted_messages) {
+            if (message.contains("role") && message["role"] == "system") {
+                // Add tools to the existing system message and update content to mention tools
+                message["tools"] = phi4_tools;
+
+                // If the system message doesn't mention tools, append that information
+                std::string content = message["content"];
+                if (content.find("tool") == std::string::npos &&
+                    content.find("function") == std::string::npos) {
+                    message["content"] = content + " You have access to some tools.";
+                }
+
+                found_system_msg = true;
+                break;
+            }
+        }
+
+        // If no system message, add one with tools
+        if (!found_system_msg && !adjusted_messages.empty()) {
+            json system_msg = {
+                {"role", "system"},
+                {"content", "You are a helpful assistant with access to tools.\nTo use a tool, respond in this format: <|tool_call|>{\"name\": \"foo\", \"arguments\": {\"a\": 1}}<|/tool_call|>"},
+                {"tools", phi4_tools}
+            };
+            // Insert system message at the beginning
+            adjusted_messages.insert(adjusted_messages.begin(), system_msg);
+        }
+
+        // Apply template with tools embedded in system message, passing empty tools separately
+        data.prompt = apply(tmpl, adjusted_messages, json(), inputs.add_generation_prompt);
+    }
+
+    data.format = COMMON_CHAT_FORMAT_PHI_4;
+    return data;
+}
+
+static common_chat_msg common_chat_parse_phi_4(const std::string & input) {
+    common_chat_msg result;
+    result.role = "assistant";
+
+    std::string final_content = "";
+
+    const std::string opening_tag = "<|tool_call|>";
+    const std::string closing_tag = "</|tool_call|>";
+
+    size_t start_pos = 0;
+    while (true) {
+        // Find next tool call
+        size_t tool_start = input.find(opening_tag, start_pos);
+        if (tool_start == std::string::npos) {
+            // No more tool calls.
+
+            // Is start_pos within string bounds?
+            if (start_pos < input.length()) {
+                // Add the rest of the string to final_content
+                final_content += input.substr(start_pos);
+            }
+            break;
+        }
+
+        // Add content before the tool call to final_content
+        final_content += input.substr(start_pos, tool_start - start_pos);
+
+        // Find closing tag
+        size_t content_start = tool_start + opening_tag.length();
+        size_t tool_end = input.find(closing_tag, content_start);
+
+        if (tool_end == std::string::npos) {
+            // No closing tag found, so just include the rest of the string as tool.
+            tool_end = input.length();
+        }
+
+        // Extract tool call content
+        std::string tool_content = input.substr(
+            content_start,
+            tool_end - content_start
+        );
+
+        // Try to parse the tool call
+        try {
+            auto tool_call = json::parse(tool_content);
+
+            // Verify the required fields exist
+            if (!tool_call.contains("name")) {
+                throw std::runtime_error("Missing 'name' field in tool call");
+            }
+
+            if (!tool_call.contains("arguments")) {
+                throw std::runtime_error("Missing 'arguments' field in tool call");
+            }
+
+            std::string name = tool_call["name"].get<std::string>();
+
+            std::string arguments;
+            try {
+                arguments = tool_call["arguments"].dump();
+            } catch (const std::exception & e) {
+                LOG_ERR("Failed to serialize arguments: %s\n", e.what());
+                arguments = "{}";
+            }
+
+            result.tool_calls.push_back({
+                name,
+                arguments,
+                /* id= */ "",
+            });
+        } catch (const std::exception & e) {
+            // If parsing fails, include the entire tool call in the content
+            final_content += input.substr(
+                tool_start,
+                tool_end + closing_tag.length() - tool_start
+            );
+        }
+
+        // Move past this tool call for next iteration
+        start_pos = tool_end + closing_tag.length();
+    }
+
+    result.content = final_content;
+    return result;
+}
+
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
@@ -1622,6 +1801,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
+    // Phi-4 mini.
+    if (src.find("<|tool|>") != std::string::npos) {
+        return common_chat_params_init_phi_4(tmpl, params);
+    }
+
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);
@@ -1756,6 +1940,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
             return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
         case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
+        case COMMON_CHAT_FORMAT_PHI_4:
+            return common_chat_parse_phi_4(input);
        default:
            throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
    }

common/chat.h

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
     COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
-
+    COMMON_CHAT_FORMAT_PHI_4,
+
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
 

models/templates/README.md

Lines changed: 1 addition & 0 deletions
@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
 ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
+./scripts/get_chat_template.py microsoft/Phi-4-mini-instruct > models/templates/microsoft-Phi-4-mini-instruct.jinja
 ```
models/templates/microsoft-Phi-4-mini-instruct.jinja

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}
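For orientation (not part of the commit): with a system message that carries a tools array, as injected by common_chat_params_init_phi_4, and add_generation_prompt set, this template should render a prompt roughly of the form below; the exact serialization of the tools array when concatenated is an assumption.

<|system|>{system content}<|tool|>{serialized tools array}<|/tool|><|end|><|user|>{user content}<|end|><|assistant|>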

tests/test-chat.cpp

Lines changed: 30 additions & 0 deletions
@@ -792,6 +792,36 @@ static void test_template_output_parsers() {
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                        "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
+    {
+        auto tmpls = read_templates("models/templates/microsoft-Phi-4-mini-instruct.jinja");
+        std::vector<std::string> end_tokens{ "<|end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_PHI_4, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        // Test normal message without tools
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+
+        // Test with content before tool call
+        assert_msg_equals(
+            common_chat_msg{"assistant", "I'll help with that.", {}, tool_calls, "", "", ""},
+            common_chat_parse(
+                "I'll help with that.<|tool_call|>{\"name\":\"special_function\",\"arguments\":{\"arg1\":1}}</|tool_call|>",
+                COMMON_CHAT_FORMAT_PHI_4));
+
+        // Test with content after tool call
+        assert_msg_equals(
+            common_chat_msg{"assistant", "I'll help with that.", {}, tool_calls, "", "", ""},
+            common_chat_parse(
+                "<|tool_call|>{\"name\":\"special_function\",\"arguments\":{\"arg1\":1}}</|tool_call|>I'll help with that.",
+                COMMON_CHAT_FORMAT_PHI_4));
+
+        // Test with newlines.
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<|tool_call|>\n"
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</|tool_call|>",
+            COMMON_CHAT_FORMAT_PHI_4));
+    }
     {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
