Commit 4e3fcc9

Hacky func streaming (#1)
* hacky function call streaming
* remove
* minor fix to handle the case where the input function has no description or its arguments are null
* test parser
* fix Makefile to make sure the file-linking order works for Ubuntu gcc/g++ 11.4
* add function name mapping to handle input function names containing a hyphen (-)
* add a TODO comment for streaming chunks
1 parent 270684a

5 files changed, +309 -321 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -808,7 +808,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 
 server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
+	$(CXX) $(CXXFLAGS) $(call GET_OBJ_FILE, $<) $(filter-out %.h %.hpp $<,$^) -Iexamples/server -o $@ $(LDFLAGS) $(LWINSOCK2)
 
 gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
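The link-order fix matters because GNU ld scans a static archive only once, at the point it appears on the command line, and extracts just the members that resolve symbols that are already undefined. Placing the freshly compiled server object before the filtered prerequisites keeps examples/server/tree_sitter/libtree-sitter.a (and the other objects) after the code that references them, which is what Ubuntu's gcc/g++ 11.4 toolchain requires.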

examples/server/python-parser.hpp

Lines changed: 11 additions & 5 deletions
@@ -42,7 +42,7 @@ static json parseValue(const std::string& content) {
 
 
 // Recursive function to parse and create JSON for the outer function calls
-static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
+static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, json tool_name_map, uint32_t indent = 0) {
     auto type = ts_node_type(node);
 
     // printf("type: %s\n", type);
@@ -60,8 +60,14 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
         TSNode argumentsNode = ts_node_child(node, 1); // The arguments node
 
         // Extract the function name
-        call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
+        std::string func_name = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
+        if (tool_name_map.find(func_name) != tool_name_map.end()){
+            call["name"] = tool_name_map[func_name];
+        } else {
+            call["name"] = func_name;
+        }
 
+        printf("function name: %s\n", call["name"].dump().c_str());
         unsigned int numArgs = ts_node_named_child_count(argumentsNode);
         for (unsigned int i = 0; i < numArgs; ++i) {
             TSNode argNode = ts_node_named_child(argumentsNode, i);
@@ -94,11 +100,11 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
     unsigned int numChildren = ts_node_child_count(node);
     for (unsigned int i = 0; i < numChildren; ++i) {
         TSNode child = ts_node_child(node, i);
-        parseFunctionCalls(child, calls, source_code, indent+1);
+        parseFunctionCalls(child, calls, source_code, tool_name_map, indent+1);
     }
 }
 
-static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
+static std::vector<json> parsePythonFunctionCalls(std::string source_string, json tool_name_map) {
     // Parse Python function calls from the source code and return a JSON array
     std::vector<json> calls;
     std::string delimiter = "<<functions>>";
@@ -124,7 +130,7 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
         return calls;
     }
 
-    parseFunctionCalls(root_node, calls, source_code_cstr, 0);
+    parseFunctionCalls(root_node, calls, source_code_cstr, tool_name_map, 0);
 
     ts_tree_delete(tree);
     ts_parser_delete(parser);
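To make the new tool_name_map plumbing concrete, here is a minimal, self-contained sketch of the round trip (assuming the repo's bundled nlohmann json.hpp; the weather-lookup tool name is invented for illustration): the prompt formatter in utils.hpp rewrites hyphenated names so the model can emit a valid Python call, and the parser above maps the name back before the response is built.

// Sketch: hyphenated tool names round-trip through tool_name_map.
// Assumes nlohmann::json (bundled as examples/server/json.hpp in this repo).
#include <algorithm>
#include <cstdio>
#include <string>
#include "json.hpp"

using json = nlohmann::json;

int main() {
    json tool_name_map;

    // Formatting side (utils.hpp): sanitize the name, remember the original.
    std::string func_name = "weather-lookup";             // hypothetical tool name
    if (func_name.find('-') != std::string::npos) {
        const std::string origin_func_name = func_name;
        std::replace(func_name.begin(), func_name.end(), '-', '_');
        tool_name_map[func_name] = origin_func_name;      // weather_lookup -> weather-lookup
    }

    // Parsing side (python-parser.hpp): restore the original name if mapped.
    std::string parsed_name = "weather_lookup";           // as emitted by the model
    json call;
    if (tool_name_map.find(parsed_name) != tool_name_map.end()) {
        call["name"] = tool_name_map[parsed_name];
    } else {
        call["name"] = parsed_name;
    }
    printf("restored name: %s\n", call["name"].dump().c_str()); // "weather-lookup"
    return 0;
}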

examples/server/server.cpp

Lines changed: 17 additions & 4 deletions
@@ -3504,7 +3504,6 @@ int main(int argc, char ** argv) {
     const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
-
         const int id_task = ctx_server.queue_tasks.get_new_id();
 
         ctx_server.queue_results.add_waiting_task_id(id_task);
@@ -3523,12 +3522,26 @@ int main(int argc, char ** argv) {
             }
             ctx_server.queue_results.remove_waiting_task_id(id_task);
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
+                std::string all_content = "";
                 while (true) {
                     server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
 
+                    std::string this_content = json_value(result.data, "content", std::string(""));
+                    // TODO: this block is just a hacky solution to enable function calling in streaming -- by concatenating the streaming chunks.
+                    // Ideally: if the first few tokens are <<functions>>, keep waiting for all chunks; otherwise follow the normal streaming logic.
+                    if (this_content != "") {
+                        all_content += this_content;
+                        continue;
+                    } else {
+                        if (all_content != "") {
+                            result.data["content"] = all_content;
+                            all_content = "";
+                        }
+                    }
+
+                    if (!result.error) {
+                        std::vector<json> result_array = format_partial_response_oaicompat(data, result.data, completion_id);
                         for (auto it = result_array.begin(); it != result_array.end(); ++it) {
                             if (!it->empty()) {
                                 const std::string str =
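The TODO block above implements streaming function calls by buffering: every non-empty content chunk is accumulated, and only the final empty-content result carries the concatenated text on to the parser. The same rule, distilled into a standalone helper (hypothetical, not part of this commit):

// Hypothetical distillation of the buffering rule used in chunked_content_provider.
#include <string>

struct FunctionCallBuffer {
    std::string all_content;

    // Returns true while chunks are being swallowed; when the stream sends an
    // empty chunk (end of generation), 'out' receives the full buffered text.
    bool consume(const std::string & chunk, std::string & out) {
        if (!chunk.empty()) {
            all_content += chunk;   // keep buffering; the caller should 'continue'
            return true;
        }
        out.swap(all_content);      // flush everything on the final chunk
        return false;
    }
};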

examples/server/utils.hpp

Lines changed: 80 additions & 11 deletions
@@ -10,6 +10,8 @@
 #include <vector>
 #include <sstream>
 #include <random>
+#include <unordered_map>
+#include <algorithm>
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
@@ -348,8 +350,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
 //
 
 
-static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
+static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
     std::string final_str = "You have access to the following tools:\n";
+    printf("rubra_format_function_call_str parsing...\n");
     json type_mapping = {
         {"string", "str"},
         {"integer", "int"},
@@ -363,10 +366,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
     std::vector<std::string> function_definitions;
     for (const auto & function : functions) {
         const auto &spec = function.contains("function") ? function["function"] : function;
-        const std::string func_name = spec.value("name", "");
-        const std::string description = spec.value("description", "");
-        const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
-        const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
+        std::string func_name = spec.value("name", "");
+        if (func_name.find('-') != std::string::npos) {
+            const std::string origin_func_name = func_name;
+            std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because "-" is invalid in a Python function name
+            tool_name_map[func_name] = origin_func_name;
+        }
+        const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
+        const auto& parameters = spec.contains("parameters") && spec["parameters"].contains("properties") ? spec["parameters"].value("properties", json({})) : json({});
+        const auto& required_params = spec.contains("parameters") && spec["parameters"].contains("properties") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
 
         std::vector<std::string> func_args;
         for (auto it = parameters.begin(); it != parameters.end(); ++it) {
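The extra contains() checks above are the commit's fix for tools that arrive with no description or with "parameters": null. A minimal reproduction of why the guard is needed (assuming the repo's bundled nlohmann json.hpp; the ping tool spec is invented): json::value() throws type_error.306 when invoked on a non-object such as null, while contains() simply returns false there.

// Why the contains("properties") guard is needed (assumes nlohmann::json).
#include <cstdio>
#include <string>
#include "json.hpp"

using json = nlohmann::json;

int main() {
    // Hypothetical tool spec with no description and null parameters.
    json spec = json::parse(R"({"name": "ping", "parameters": null})");

    // Old code: spec["parameters"].value("properties", json({})) would throw
    // type_error.306 here, because spec["parameters"] is null, not an object.
    // New code: contains() returns false on non-objects, so the guard short-circuits.
    const bool has_props = spec.contains("parameters")
                        && spec["parameters"].contains("properties");
    const json parameters = has_props ? spec["parameters"]["properties"] : json({});
    const std::string description = spec.contains("description")
                                  ? spec["description"].get<std::string>() : "";

    printf("parameters: %s, description: \"%s\"\n",
           parameters.dump().c_str(), description.c_str()); // parameters: {}, description: ""
    return 0;
}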
@@ -492,15 +500,16 @@ static json oaicompat_completion_params_parse(
     llama_params["__oaicompat"] = true;
 
     std::string function_str = "";
+    json tool_name_map;
 
     if (body.contains("tools") && !body["tools"].empty()) {
         // function_str = default_tool_formatter(body["tool"]);
-        function_str = rubra_format_function_call_str(body["tools"]);
+        function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
     }
     // If 'tool' is not set or empty, check 'functions'
     else if (body.contains("functions") && !body["functions"].empty()) {
         // function_str = default_tool_formatter(body["functions"]);
-        function_str = rubra_format_function_call_str(body["functions"]);
+        function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
     }
     printf("\n=============Formatting Input from OPENAI format...============\n");
     if (function_str != "") {
@@ -618,6 +627,7 @@ static json oaicompat_completion_params_parse(
     else {
         llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
     }
+    llama_params["tool_name_map"] = tool_name_map;
 
     // Map OpenAI parameters to llama.cpp parameters
     //
@@ -705,9 +715,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
     int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
     std::string content = json_value(result, "content", std::string(""));
 
-    std::vector<json> parsed_content = parsePythonFunctionCalls(content);
+    std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
 
-
     std::string finish_reason = "length";
     if (stopped_word || stopped_eos) {
         finish_reason = "stop";
@@ -776,7 +785,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
     }
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
+static std::vector<json> format_partial_response_oaicompat(json request, json result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -789,6 +798,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     bool stopped_limit = json_value(result, "stopped_limit", false);
     std::string content = json_value(result, "content", std::string(""));
 
+    std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
+    std::time_t t = std::time(0);
+    if (!parsed_content.empty()) {
+        std::vector<json> res;
+        json choices1 = json::array({json{{"finish_reason", nullptr},
+                                          {"index", 0},
+                                          {"delta", json{{"role", "assistant"}}}}});
+
+        json ret = json{
+            {"choices", choices1},
+            {"created", t},
+            {"id", completion_id},
+            {"model", modelname},
+            {"object", "chat.completion.chunk"}
+        };
+        res.push_back(ret);
+
+        for (size_t i = 0; i < parsed_content.size(); ++i) {
+            const auto &pc = parsed_content[i];
+            // Use 'pc' and 'i' as needed
+            json tool_call1;
+            tool_call1["id"] = pc["id"];
+            tool_call1["type"] = "function";
+            tool_call1["index"] = i;
+            tool_call1["function"] = json{
+                {"name", pc["name"]},
+                {"arguments", ""},
+            };
+            json ret1 = json{
+                {"choices", json::array({json{{"finish_reason", nullptr},
+                                              {"index", 0},
+                                              {"delta", json{{"tool_calls", std::vector<json>{tool_call1}}}}}})
+                },
+                {"created", t},
+                {"id", completion_id},
+                {"model", modelname},
+                {"object", "chat.completion.chunk"}
+            };
+            res.push_back(ret1);
+            json tool_call2;
+            tool_call2["index"] = i;
+            tool_call2["function"] = json{
+                {"name", ""},
+                {"arguments", pc["kwargs"].dump()},
+            };
+            json ret2 = json{
+                {"choices", json::array({json{{"finish_reason", nullptr},
+                                              {"index", 0},
+                                              {"delta", json{{"tool_calls", std::vector<json>{tool_call2}}}}}})
+                },
+                {"created", t},
+                {"id", completion_id},
+                {"model", modelname},
+                {"object", "chat.completion.chunk"}
+            };
+            res.push_back(ret2);
+        }
+        return res;
+    }
+
     std::string finish_reason;
     if (stopped_word || stopped_eos) {
         finish_reason = "stop";
@@ -797,7 +866,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
         finish_reason = "length";
     }
 
-    std::time_t t = std::time(0);
+
 
     json choices;
