10
10
#include < vector>
11
11
#include < sstream>
12
12
#include < random>
13
+ #include < unordered_map>
14
+ #include < algorithm>
13
15
14
16
#define DEFAULT_OAICOMPAT_MODEL " gpt-3.5-turbo-0613"
15
17
@@ -337,8 +339,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
337
339
//
338
340
339
341
340
- static std::string rubra_format_function_call_str (const std::vector<json> & functions) {
342
+ static std::string rubra_format_function_call_str (const std::vector<json> & functions, json & tool_name_map ) {
341
343
std::string final_str = " You have access to the following tools:\n " ;
344
+ printf (" rubra_format_function_call_str parsing...\n " );
342
345
json type_mapping = {
343
346
{" string" , " str" },
344
347
{" integer" , " int" },
@@ -352,10 +355,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
352
355
std::vector<std::string> function_definitions;
353
356
for (const auto & function : functions) {
354
357
const auto &spec = function.contains (" function" ) ? function[" function" ] : function;
355
- const std::string func_name = spec.value (" name" , " " );
356
- const std::string description = spec.value (" description" , " " );
357
- const auto & parameters = spec.contains (" parameters" ) ? spec[" parameters" ].value (" properties" , json ({})) : json ({});
358
- const auto & required_params = spec.contains (" parameters" ) ? spec[" parameters" ].value (" required" , std::vector<std::string>()) : std::vector<std::string>();
358
+ std::string func_name = spec.value (" name" , " " );
359
+ if (func_name.find (' -' ) != std::string::npos) {
360
+ const std::string origin_func_name = func_name;
361
+ std::replace (func_name.begin (), func_name.end (), ' -' , ' _' ); // replace "-" with "_" because - is invalid in python func name
362
+ tool_name_map[func_name] = origin_func_name;
363
+ }
364
+ const std::string description = spec.contains (" description" ) ? spec[" description" ].get <std::string>() : " " ;
365
+ const auto & parameters = spec.contains (" parameters" ) && spec[" parameters" ].contains (" properties" )? spec[" parameters" ].value (" properties" , json ({})) : json ({});
366
+ const auto & required_params = spec.contains (" parameters" ) && spec[" parameters" ].contains (" properties" )? spec[" parameters" ].value (" required" , std::vector<std::string>()) : std::vector<std::string>();
359
367
360
368
std::vector<std::string> func_args;
361
369
for (auto it = parameters.begin (); it != parameters.end (); ++it) {
@@ -481,15 +489,16 @@ static json oaicompat_completion_params_parse(
481
489
llama_params[" __oaicompat" ] = true ;
482
490
483
491
std::string function_str = " " ;
492
+ json tool_name_map;
484
493
485
494
if (body.contains (" tools" ) && !body[" tools" ].empty ()) {
486
495
// function_str = default_tool_formatter(body["tool"]);
487
- function_str = rubra_format_function_call_str (body[" tools" ]);
496
+ function_str = rubra_format_function_call_str (body[" tools" ], tool_name_map );
488
497
}
489
498
// If 'tools' is not set or empty, check 'functions'
490
499
else if (body.contains (" functions" ) && !body[" functions" ].empty ()) {
491
500
// function_str = default_tool_formatter(body["functions"]);
492
- function_str = rubra_format_function_call_str (body[" functions" ]);
501
+ function_str = rubra_format_function_call_str (body[" functions" ], tool_name_map );
493
502
}
494
503
printf (" \n =============Formatting Input from OPENAI format...============\n " );
495
504
if (function_str != " " ) {
@@ -607,6 +616,7 @@ static json oaicompat_completion_params_parse(
607
616
else {
608
617
llama_params[" prompt" ] = format_chat (model, chat_template, body[" messages" ]);
609
618
}
619
+ llama_params[" tool_name_map" ] = tool_name_map;
610
620
611
621
// Map OpenAI parameters to llama.cpp parameters
612
622
//
@@ -661,9 +671,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
661
671
int num_prompt_tokens = json_value (result, " tokens_evaluated" , 0 );
662
672
std::string content = json_value (result, " content" , std::string (" " ));
663
673
664
- std::vector<json> parsed_content = parsePythonFunctionCalls (content);
674
+ std::vector<json> parsed_content = parsePythonFunctionCalls (content, request[ " tool_name_map " ] );
665
675
666
-
667
676
std::string finish_reason = " length" ;
668
677
if (stopped_word || stopped_eos) {
669
678
finish_reason = " stop" ;
@@ -732,7 +741,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
732
741
}
733
742
734
743
// The return value is a vector because a single result may need to be emitted as more than one response chunk (e.g. parsed tool calls produce several chunks).
735
- static std::vector<json> format_partial_response_oaicompat (json result, const std::string & completion_id) {
744
+ static std::vector<json> format_partial_response_oaicompat (json request ,json result, const std::string & completion_id) {
736
745
if (!result.contains (" model" ) || !result.contains (" oaicompat_token_ctr" )) {
737
746
return std::vector<json>({result});
738
747
}
@@ -745,6 +754,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
745
754
bool stopped_limit = json_value (result, " stopped_limit" , false );
746
755
std::string content = json_value (result, " content" , std::string (" " ));
747
756
757
+ std::vector<json> parsed_content = parsePythonFunctionCalls (content, request[" tool_name_map" ]);
758
+ std::time_t t = std::time (0 );
759
+ if (!parsed_content.empty ()) {
760
+ std::vector<json> res;
761
+ json choices1 = json::array ({json{{" finish_reason" , nullptr },
762
+ {" index" , 0 },
763
+ {" delta" , json{{" role" , " assistant" }}}}});
764
+
765
+ json ret = json{
766
+ {" choices" , choices1},
767
+ {" created" , t},
768
+ {" id" , completion_id},
769
+ {" model" , modelname},
770
+ {" object" , " chat.completion.chunk" }
771
+ };
772
+ res.push_back (ret);
773
+
774
+ for (size_t i = 0 ; i < parsed_content.size (); ++i) {
775
+ const auto &pc = parsed_content[i];
776
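+ // Emit two chunks per parsed call: the first carries id/type/name with empty arguments, the second carries the serialized kwargs at the same index.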
777
+ json tool_call1;
778
+ tool_call1[" id" ] = pc[" id" ];
779
+ tool_call1[" type" ] = " function" ;
780
+ tool_call1[" index" ] = i;
781
+ tool_call1[" function" ] = json{
782
+ {" name" , pc[" name" ]},
783
+ {" arguments" , " " },
784
+ };
785
+ json ret1 = json{
786
+ {" choices" , json::array ({json{{" finish_reason" , nullptr },
787
+ {" index" , 0 },
788
+ {" delta" , json{{" tool_calls" , std::vector<json>{tool_call1}}}}}})
789
+ },
790
+ {" created" , t},
791
+ {" id" , completion_id},
792
+ {" model" , modelname},
793
+ {" object" , " chat.completion.chunk" }
794
+ };
795
+ res.push_back (ret1);
796
+ json tool_call2;
797
+ tool_call2[" index" ] = i;
798
+ tool_call2[" function" ] = json{
799
+ {" name" , " " },
800
+ {" arguments" , pc[" kwargs" ].dump ()},
801
+ };
802
+ json ret2 = json{
803
+ {" choices" , json::array ({json{{" finish_reason" , nullptr },
804
+ {" index" , 0 },
805
+ {" delta" , json{{" tool_calls" , std::vector<json>{tool_call2}}}}}})
806
+ },
807
+ {" created" , t},
808
+ {" id" , completion_id},
809
+ {" model" , modelname},
810
+ {" object" , " chat.completion.chunk" }
811
+ };
812
+ res.push_back (ret2);
813
+ }
814
+ return res;
815
+ }
816
+
748
817
std::string finish_reason;
749
818
if (stopped_word || stopped_eos) {
750
819
finish_reason = " stop" ;
@@ -753,7 +822,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
753
822
finish_reason = " length" ;
754
823
}
755
824
756
- std:: time_t t = std::time ( 0 );
825
+
757
826
758
827
json choices;
759
828
0 commit comments