Skip to content

Hacky func streaming #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C

server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o scanner.o parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
$(CXX) $(CXXFLAGS) $(call GET_OBJ_FILE, $<) $(filter-out %.h %.hpp $<,$^) -Iexamples/server -o $@ $(LDFLAGS) $(LWINSOCK2)

gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
Expand Down
16 changes: 11 additions & 5 deletions examples/server/python-parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ static json parseValue(const std::string& content) {


// Recursive function to parse and create JSON for the outer function calls
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, uint32_t indent = 0) {
static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, const char* source_code, json tool_name_map, uint32_t indent = 0) {
auto type = ts_node_type(node);

// printf("type: %s\n", type);
Expand All @@ -60,8 +60,14 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
TSNode argumentsNode = ts_node_child(node, 1); // The arguments node

// Extract the function name
call["name"] = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
std::string func_name = std::string(source_code + ts_node_start_byte(functionNode), ts_node_end_byte(functionNode) - ts_node_start_byte(functionNode));
if (tool_name_map.find(func_name) != tool_name_map.end()){
call["name"] = tool_name_map[func_name];
} else {
call["name"] = func_name;
}

printf("function name: %s\n", call["name"].dump().c_str());
unsigned int numArgs = ts_node_named_child_count(argumentsNode);
for (unsigned int i = 0; i < numArgs; ++i) {
TSNode argNode = ts_node_named_child(argumentsNode, i);
Expand Down Expand Up @@ -94,11 +100,11 @@ static void parseFunctionCalls(const TSNode& node, std::vector<json>& calls, con
unsigned int numChildren = ts_node_child_count(node);
for (unsigned int i = 0; i < numChildren; ++i) {
TSNode child = ts_node_child(node, i);
parseFunctionCalls(child, calls, source_code, indent+1);
parseFunctionCalls(child, calls, source_code, tool_name_map, indent+1);
}
}

static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
static std::vector<json> parsePythonFunctionCalls(std::string source_string, json tool_name_map) {
// Parse Python function calls from the source code and return a JSON array
std::vector<json> calls;
std::string delimiter = "<<functions>>";
Expand All @@ -124,7 +130,7 @@ static std::vector<json> parsePythonFunctionCalls(std::string source_string) {
return calls;
}

parseFunctionCalls(root_node, calls, source_code_cstr, 0);
parseFunctionCalls(root_node, calls, source_code_cstr,tool_name_map, 0);

ts_tree_delete(tree);
ts_parser_delete(parser);
Expand Down
21 changes: 17 additions & 4 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3230,7 +3230,6 @@ int main(int argc, char ** argv) {
const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);

const int id_task = ctx_server.queue_tasks.get_new_id();

ctx_server.queue_results.add_waiting_task_id(id_task);
Expand All @@ -3249,12 +3248,26 @@ int main(int argc, char ** argv) {
}
ctx_server.queue_results.remove_waiting_task_id(id_task);
} else {
const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
std::string all_content = "";
while (true) {
server_task_result result = ctx_server.queue_results.recv(id_task);
if (!result.error) {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);

std::string this_content = json_value(result.data, "content", std::string(""));
// TODO: this block is just a hacky solution to enable function calling in streaming -- by concat the streaming chunks.
                        // Ideally: if the first few tokens are <<functions>>, it should keep waiting for all chunks; otherwise do the normal stream logic.
if (this_content != "") {
all_content += this_content;
continue;
} else {
if (all_content != "") {
result.data["content"] = all_content;
all_content = "";
}
}

if (!result.error) {
std::vector<json> result_array = format_partial_response_oaicompat(data, result.data, completion_id);
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
if (!it->empty()) {
const std::string str =
Expand Down
91 changes: 80 additions & 11 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <vector>
#include <sstream>
#include <random>
#include <unordered_map>
#include <algorithm>

#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

Expand Down Expand Up @@ -337,8 +339,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
//


static std::string rubra_format_function_call_str(const std::vector<json> & functions) {
static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
std::string final_str = "You have access to the following tools:\n";
printf("rubra_format_function_call_str parsing...\n");
json type_mapping = {
{"string", "str"},
{"integer", "int"},
Expand All @@ -352,10 +355,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
std::vector<std::string> function_definitions;
for (const auto & function : functions) {
const auto &spec = function.contains("function") ? function["function"] : function;
const std::string func_name = spec.value("name", "");
const std::string description = spec.value("description", "");
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
std::string func_name = spec.value("name", "");
if (func_name.find('-') != std::string::npos) {
const std::string origin_func_name = func_name;
std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because - is invalid in python func name
tool_name_map[func_name] = origin_func_name;
}
const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
const auto& parameters = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("properties", json({})) : json({});
const auto& required_params = spec.contains("parameters") && spec["parameters"].contains("properties")? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();

std::vector<std::string> func_args;
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
Expand Down Expand Up @@ -481,15 +489,16 @@ static json oaicompat_completion_params_parse(
llama_params["__oaicompat"] = true;

std::string function_str = "";
json tool_name_map;

if (body.contains("tools") && !body["tools"].empty()) {
// function_str = default_tool_formatter(body["tool"]);
function_str = rubra_format_function_call_str(body["tools"]);
function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
}
// If 'tool' is not set or empty, check 'functions'
else if (body.contains("functions") && !body["functions"].empty()) {
// function_str = default_tool_formatter(body["functions"]);
function_str = rubra_format_function_call_str(body["functions"]);
function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
}
printf("\n=============Formatting Input from OPENAI format...============\n");
if (function_str != "") {
Expand Down Expand Up @@ -607,6 +616,7 @@ static json oaicompat_completion_params_parse(
else {
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
}
llama_params["tool_name_map"] = tool_name_map;

// Map OpenAI parameters to llama.cpp parameters
//
Expand Down Expand Up @@ -661,9 +671,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));

std::vector<json> parsed_content = parsePythonFunctionCalls(content);
std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);


std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
Expand Down Expand Up @@ -732,7 +741,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
}

// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
static std::vector<json> format_partial_response_oaicompat(json request ,json result, const std::string & completion_id) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
Expand All @@ -745,6 +754,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));

std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
std::time_t t = std::time(0);
if (!parsed_content.empty()) {
std::vector<json> res;
json choices1 = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});

json ret = json{
{"choices", choices1},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
res.push_back(ret);

for (size_t i = 0; i < parsed_content.size(); ++i) {
const auto &pc = parsed_content[i];
// Use 'pc' and 'i' as needed
json tool_call1;
tool_call1["id"] = pc["id"];
tool_call1["type"] = "function";
tool_call1["index"] = i;
tool_call1["function"] = json{
{"name" , pc["name"]},
{"arguments" , ""},
};
json ret1 = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"tool_calls", std::vector<json>{tool_call1}}}}}})
},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
res.push_back(ret1);
json tool_call2;
tool_call2["index"] = i;
tool_call2["function"] = json{
{"name" , ""},
{"arguments" , pc["kwargs"].dump()},
};
json ret2 = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"tool_calls", std::vector<json>{tool_call2}}}}}})
},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
res.push_back(ret2);
}
return res;
}

std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
Expand All @@ -753,7 +822,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
finish_reason = "length";
}

std::time_t t = std::time(0);


json choices;

Expand Down
Loading