Skip to content

Commit b8d75ab

Browse files
committed
add support to format input json as typescript function str
1 parent 9e06741 commit b8d75ab

File tree

1 file changed

+101
-3
lines changed

1 file changed

+101
-3
lines changed

examples/server/utils.hpp

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
350350
//
351351

352352

353-
static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
353+
static std::string rubra_format_python_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
354354
std::string final_str = "You have access to the following tools:\n";
355355
printf("rubra_format_function_call_str parsing...\n");
356356
json type_mapping = {
@@ -443,6 +443,104 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
443443
return final_str;
444444
}
445445

446+
447+
// Helper function to join strings with a delimiter
448+
static std::string helper_join(const std::vector<std::string>& elements, const std::string& delimiter) {
449+
std::string result;
450+
for (auto it = elements.begin(); it != elements.end(); ++it) {
451+
if (!result.empty()) {
452+
result += delimiter;
453+
}
454+
result += *it;
455+
}
456+
return result;
457+
}
458+
459+
static std::string rubra_format_typescript_function_call_str(const std::vector<json> &functions, json &tool_name_map) {
460+
std::string final_str = "You have access to the following tools:\n";
461+
json type_mapping = {
462+
{"string", "string"},
463+
{"integer", "number"},
464+
{"number", "number"},
465+
{"float", "number"},
466+
{"object", "any"},
467+
{"array", "any[]"},
468+
{"boolean", "boolean"},
469+
{"null", "null"}
470+
};
471+
472+
std::vector<std::string> function_definitions;
473+
for (const auto &function : functions) {
474+
const auto &spec = function.contains("function") ? function["function"] : function;
475+
std::string func_name = spec.value("name", "");
476+
if (func_name.find('-') != std::string::npos) {
477+
const std::string origin_func_name = func_name;
478+
std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because - is invalid in typescript func name
479+
tool_name_map[func_name] = origin_func_name;
480+
}
481+
482+
const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
483+
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
484+
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();
485+
486+
std::vector<std::string> func_args;
487+
std::string docstring = "/**\n * " + description + "\n";
488+
489+
for (auto it = parameters.begin(); it != parameters.end(); ++it) {
490+
const std::string param = it.key();
491+
const json& details = it.value();
492+
std::string json_type = details["type"].get<std::string>();
493+
std::string ts_type = type_mapping.value(json_type, "any");
494+
std::string param_description = "";
495+
if (details.count("description") > 0) {
496+
param_description = details["description"]; // Assuming the description is the first element
497+
}
498+
if (details.count("enum") > 0) {
499+
std::string enum_values;
500+
for (const std::string val : details["enum"]) {
501+
if (!enum_values.empty()) {
502+
enum_values += " or ";
503+
}
504+
enum_values = enum_values+ "\"" + val + "\"";
505+
}
506+
if (details["enum"].size() == 1) {
507+
param_description += " Only Acceptable value is: " + enum_values;
508+
} else {
509+
param_description += " Only Acceptable values are: " + enum_values;
510+
}
511+
}
512+
if (param_description.empty()) {
513+
param_description = "No description provided.";
514+
}
515+
if (details.contains("enum")) {
516+
ts_type = "string"; // Enum is treated as string in typescript
517+
}
518+
std::string arg_str = param + ": " + ts_type;
519+
if (find(required_params.begin(), required_params.end(), param) == required_params.end()) {
520+
arg_str = param + "?: " + ts_type;
521+
docstring += " * @param " + param + " - " + param_description + "\n";
522+
} else {
523+
docstring += " * @param " + param + " - " + param_description + "\n";
524+
}
525+
func_args.push_back(arg_str);
526+
}
527+
docstring += " */\n";
528+
529+
std::string func_args_str = helper_join(func_args, ", ");
530+
std::string function_definition = docstring + "function " + func_name + "(" + func_args_str + "): any {}";
531+
532+
function_definitions.push_back(function_definition);
533+
}
534+
535+
for (const auto& def : function_definitions) {
536+
final_str += def + "\n\n";
537+
}
538+
final_str += "Use the following format if using tools:\n<<functions>>[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]";
539+
return final_str;
540+
}
541+
542+
543+
446544
static std::string default_tool_formatter(const std::vector<json>& tools) {
447545
std::string toolText = "";
448546
std::vector<std::string> toolNames;
@@ -504,12 +602,12 @@ static json oaicompat_completion_params_parse(
504602

505603
if (body.contains("tools") && !body["tools"].empty()) {
506604
// function_str = default_tool_formatter(body["tool"]);
507-
function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
605+
function_str = rubra_format_typescript_function_call_str(body["tools"], tool_name_map);
508606
}
509607
// If 'tool' is not set or empty, check 'functions'
510608
else if (body.contains("functions") && !body["functions"].empty()) {
511609
// function_str = default_tool_formatter(body["functions"]);
512-
function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
610+
function_str = rubra_format_typescript_function_call_str(body["functions"], tool_name_map);
513611
}
514612
printf("\n=============Formatting Input from OPENAI format...============\n");
515613
if (function_str != "") {

0 commit comments

Comments
 (0)