Skip to content

Commit 60a01b3

Browse files
authored
Hacky func streaming (#1)
* hacky function call streaming * remove * minor fix to handle the case where the input function has no description or its arguments are null * test parser * fix makefile to ensure the file-linking order works for Ubuntu gcc/g++ 11.4 * add function name mapping to handle input function names containing hyphens * add a TODO comment for streaming chunks.
1 parent 1eafdc9 commit 60a01b3

File tree

5 files changed

+309
-321
lines changed

5 files changed

+309
-321
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
753753

754754
server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
755755
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
756-
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
756+
$(CXX) $(CXXFLAGS) $(call GET_OBJ_FILE, $<) $(filter-out %.h %.hpp $<,$^) -Iexamples/server -o $@ $(LDFLAGS) $(LWINSOCK2)
757757

758758
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
759759
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

examples/server/python-parser.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ static json parseValue(const std::string& content) {
4242

4343

4444
// Recursive function to parse and create JSON for the outer function calls
45-
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
45+
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, json tool_name_map, uint32_t indent = 0) {
4646
auto type = ts_node_type(node);
4747

4848
// printf("type: %s\n", type);
@@ -60,8 +60,14 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
6060
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
6161

6262
// Extract the function name
63-
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
63+
std::string func_name = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
64+
if (tool_name_map.find(func_name) != tool_name_map.end()){
65+
call["name"] = tool_name_map[func_name];
66+
} else {
67+
call["name"] = func_name;
68+
}
6469

70+
printf("function name: %s\n", call["name"].dump().c_str());
6571
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
6672
for (unsigned int i = 0; i < numArgs; ++i) {
6773
TSNode argNode = ts_node_named_child(argumentsNode, i);
@@ -94,11 +100,11 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
94100
unsigned int numChildren = ts_node_child_count(node);
95101
for (unsigned int i = 0; i < numChildren; ++i) {
96102
TSNode child = ts_node_child(node, i);
97-
parseFunctionCalls(child, calls, source_code, indent+1);
103+
parseFunctionCalls(child, calls, source_code, tool_name_map, indent+1);
98104
}
99105
}
100106

101-
static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
107+
static std::vector<json> parsePythonFunctionCalls(std::string source_string, json tool_name_map) {
102108
// Parse Python function calls from the source code and return a JSON array
103109
std::vector<json> calls;
104110
std::string delimiter = "<<functions>>";
@@ -124,7 +130,7 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
124130
return calls;
125131
}
126132

127-
parseFunctionCalls(root_node, calls, source_code_cstr, 0);
133+
parseFunctionCalls(root_node, calls, source_code_cstr,tool_name_map, 0);
128134

129135
ts_tree_delete(tree);
130136
ts_parser_delete(parser);

examples/server/server.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,7 +3230,6 @@ int main(int argc, char ** argv) {
32303230
const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
32313231
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32323232
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
3233-
32343233
const int id_task = ctx_server.queue_tasks.get_new_id();
32353234

32363235
ctx_server.queue_results.add_waiting_task_id(id_task);
@@ -3249,12 +3248,26 @@ int main(int argc, char ** argv) {
32493248
}
32503249
ctx_server.queue_results.remove_waiting_task_id(id_task);
32513250
} else {
3252-
const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
3251+
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
3252+
std::string all_content = "";
32533253
while (true) {
32543254
server_task_result result = ctx_server.queue_results.recv(id_task);
3255-
if (!result.error) {
3256-
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
32573255

3256+
std::string this_content = json_value(result.data, "content", std::string(""));
3257+
// TODO: this block is just a hacky solution to enable function calling in streaming -- by concat the streaming chunks.
3258+
// Ideally: If the first a few tokens is <<functions>>, it should keep waiting for all chunks, otherwise do normal stream logic.
3259+
if (this_content != "") {
3260+
all_content += this_content;
3261+
continue;
3262+
} else {
3263+
if (all_content != "") {
3264+
result.data["content"] = all_content;
3265+
all_content = "";
3266+
}
3267+
}
3268+
3269+
if (!result.error) {
3270+
std::vector<json> result_array = format_partial_response_oaicompat(data, result.data, completion_id);
32583271
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
32593272
if (!it->empty()) {
32603273
const std::string str =

examples/server/utils.hpp

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <vector>
1111
#include <sstream>
1212
#include <random>
13+
#include <unordered_map>
14+
#include <algorithm>
1315

1416
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
1517

@@ -337,8 +339,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
337339
//
338340

339341

340-
static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
342+
static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
341343
std::string final_str = "You have access to the following tools:\n";
344+
printf("rubra_format_function_call_str parsing...\n");
342345
json type_mapping = {
343346
{"string", "str"},
344347
{"integer", "int"},
@@ -352,10 +355,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
352355
std::vector<std::string> function_definitions;
353356
for (const auto & function : functions) {
354357
const auto &spec = function.contains("function") ? function["function"] : function;
355-
const std::string func_name = spec.value("name", "");
356-
const std::string description = spec.value("description", "");
357-
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
358-
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
358+
std::string func_name = spec.value("name", "");
359+
if (func_name.find('-') != std::string::npos) {
360+
const std::string origin_func_name = func_name;
361+
std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because - is invalid in python func name
362+
tool_name_map[func_name] = origin_func_name;
363+
}
364+
const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
365+
const auto& parameters = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("properties", json({})) : json({});
366+
const auto& required_params = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
359367

360368
std::vector<std::string> func_args;
361369
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
@@ -481,15 +489,16 @@ static json oaicompat_completion_params_parse(
481489
llama_params["__oaicompat"] = true;
482490

483491
std::string function_str = "";
492+
json tool_name_map;
484493

485494
if (body.contains("tools") && !body["tools"].empty()) {
486495
// function_str = default_tool_formatter(body["tool"]);
487-
function_str = rubra_format_function_call_str(body["tools"]);
496+
function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
488497
}
489498
// If 'tool' is not set or empty, check 'functions'
490499
else if (body.contains("functions") && !body["functions"].empty()) {
491500
// function_str = default_tool_formatter(body["functions"]);
492-
function_str = rubra_format_function_call_str(body["functions"]);
501+
function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
493502
}
494503
printf("\n=============Formatting Input from OPENAI format...============\n");
495504
if (function_str != "") {
@@ -607,6 +616,7 @@ static json oaicompat_completion_params_parse(
607616
else {
608617
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
609618
}
619+
llama_params["tool_name_map"] = tool_name_map;
610620

611621
// Map OpenAI parameters to llama.cpp parameters
612622
//
@@ -661,9 +671,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
661671
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
662672
std::string content = json_value(result, "content", std::string(""));
663673

664-
std::vector<json> parsed_content = parsePythonFunctionCalls(content);
674+
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
665675

666-
667676
std::string finish_reason = "length";
668677
if (stopped_word || stopped_eos) {
669678
finish_reason = "stop";
@@ -732,7 +741,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
732741
}
733742

734743
// return value is vector as there is one case where we might need to generate two responses
735-
static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
744+
static std::vector<json> format_partial_response_oaicompat(json request ,json result, const std::string & completion_id) {
736745
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
737746
return std::vector<json>({result});
738747
}
@@ -745,6 +754,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
745754
bool stopped_limit = json_value(result, "stopped_limit", false);
746755
std::string content = json_value(result, "content", std::string(""));
747756

757+
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
758+
std::time_t t = std::time(0);
759+
if (!parsed_content.empty()) {
760+
std::vector<json> res;
761+
json choices1 = json::array({json{{"finish_reason", nullptr},
762+
{"index", 0},
763+
{"delta", json{{"role", "assistant"}}}}});
764+
765+
json ret = json{
766+
{"choices", choices1},
767+
{"created", t},
768+
{"id", completion_id},
769+
{"model", modelname},
770+
{"object", "chat.completion.chunk"}
771+
};
772+
res.push_back(ret);
773+
774+
for (size_t i = 0; i < parsed_content.size(); ++i) {
775+
const auto &pc = parsed_content[i];
776+
// Use 'pc' and 'i' as needed
777+
json tool_call1;
778+
tool_call1["id"] = pc["id"];
779+
tool_call1["type"] = "function";
780+
tool_call1["index"] = i;
781+
tool_call1["function"] = json{
782+
{"name" , pc["name"]},
783+
{"arguments" , ""},
784+
};
785+
json ret1 = json{
786+
{"choices", json::array({json{{"finish_reason", nullptr},
787+
{"index", 0},
788+
{"delta", json{{"tool_calls", std::vector<json>{tool_call1}}}}}})
789+
},
790+
{"created", t},
791+
{"id", completion_id},
792+
{"model", modelname},
793+
{"object", "chat.completion.chunk"}
794+
};
795+
res.push_back(ret1);
796+
json tool_call2;
797+
tool_call2["index"] = i;
798+
tool_call2["function"] = json{
799+
{"name" , ""},
800+
{"arguments" , pc["kwargs"].dump()},
801+
};
802+
json ret2 = json{
803+
{"choices", json::array({json{{"finish_reason", nullptr},
804+
{"index", 0},
805+
{"delta", json{{"tool_calls", std::vector<json>{tool_call2}}}}}})
806+
},
807+
{"created", t},
808+
{"id", completion_id},
809+
{"model", modelname},
810+
{"object", "chat.completion.chunk"}
811+
};
812+
res.push_back(ret2);
813+
}
814+
return res;
815+
}
816+
748817
std::string finish_reason;
749818
if (stopped_word || stopped_eos) {
750819
finish_reason = "stop";
@@ -753,7 +822,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
753822
finish_reason = "length";
754823
}
755824

756-
std::time_t t = std::time(0);
825+
757826

758827
json choices;
759828

0 commit comments

Comments (0)