10
10
#include < vector>
11
11
#include < sstream>
12
12
#include < random>
13
+ #include < unordered_map>
14
+ #include < algorithm>
13
15
14
16
#define DEFAULT_OAICOMPAT_MODEL " gpt-3.5-turbo-0613"
15
17
@@ -348,8 +350,9 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
348
350
//
349
351
350
352
351
- static std::string rubra_format_function_call_str (const std::vector<json> & functions) {
353
+ static std::string rubra_format_function_call_str (const std::vector<json> & functions, json & tool_name_map ) {
352
354
std::string final_str = " You have access to the following tools:\n " ;
355
+ printf (" rubra_format_function_call_str parsing...\n " );
353
356
json type_mapping = {
354
357
{" string" , " str" },
355
358
{" integer" , " int" },
@@ -363,10 +366,15 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
363
366
std::vector<std::string> function_definitions;
364
367
for (const auto & function : functions) {
365
368
const auto &spec = function.contains (" function" ) ? function[" function" ] : function;
366
- const std::string func_name = spec.value (" name" , " " );
367
- const std::string description = spec.value (" description" , " " );
368
- const auto & parameters = spec.contains (" parameters" ) ? spec[" parameters" ].value (" properties" , json ({})) : json ({});
369
- const auto & required_params = spec.contains (" parameters" ) ? spec[" parameters" ].value (" required" , std::vector<std::string>()) : std::vector<std::string>();
369
+ std::string func_name = spec.value (" name" , " " );
370
+ if (func_name.find (' -' ) != std::string::npos) {
371
+ const std::string origin_func_name = func_name;
372
+ std::replace (func_name.begin (), func_name.end (), ' -' , ' _' ); // replace "-" with "_" because - is invalid in python func name
373
+ tool_name_map[func_name] = origin_func_name;
374
+ }
375
+ const std::string description = spec.contains (" description" ) ? spec[" description" ].get <std::string>() : " " ;
376
+ const auto & parameters = spec.contains (" parameters" ) && spec[" parameters" ].contains (" properties" )? spec[" parameters" ].value (" properties" , json ({})) : json ({});
377
+ const auto & required_params = spec.contains (" parameters" ) && spec[" parameters" ].contains (" properties" )? spec[" parameters" ].value (" required" , std::vector<std::string>()) : std::vector<std::string>();
370
378
371
379
std::vector<std::string> func_args;
372
380
for (auto it = parameters.begin (); it != parameters.end (); ++it) {
@@ -492,15 +500,16 @@ static json oaicompat_completion_params_parse(
492
500
llama_params[" __oaicompat" ] = true ;
493
501
494
502
std::string function_str = " " ;
503
+ json tool_name_map;
495
504
496
505
if (body.contains (" tools" ) && !body[" tools" ].empty ()) {
497
506
// function_str = default_tool_formatter(body["tool"]);
498
- function_str = rubra_format_function_call_str (body[" tools" ]);
507
+ function_str = rubra_format_function_call_str (body[" tools" ], tool_name_map );
499
508
}
500
509
// If 'tool' is not set or empty, check 'functions'
501
510
else if (body.contains (" functions" ) && !body[" functions" ].empty ()) {
502
511
// function_str = default_tool_formatter(body["functions"]);
503
- function_str = rubra_format_function_call_str (body[" functions" ]);
512
+ function_str = rubra_format_function_call_str (body[" functions" ], tool_name_map );
504
513
}
505
514
printf (" \n =============Formatting Input from OPENAI format...============\n " );
506
515
if (function_str != " " ) {
@@ -618,6 +627,7 @@ static json oaicompat_completion_params_parse(
618
627
else {
619
628
llama_params[" prompt" ] = format_chat (model, chat_template, body[" messages" ]);
620
629
}
630
+ llama_params[" tool_name_map" ] = tool_name_map;
621
631
622
632
// Map OpenAI parameters to llama.cpp parameters
623
633
//
@@ -705,9 +715,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
705
715
int num_prompt_tokens = json_value (result, " tokens_evaluated" , 0 );
706
716
std::string content = json_value (result, " content" , std::string (" " ));
707
717
708
- std::vector<json> parsed_content = parsePythonFunctionCalls (content);
718
+ std::vector<json> parsed_content = parsePythonFunctionCalls (content, request[ " tool_name_map " ] );
709
719
710
-
711
720
std::string finish_reason = " length" ;
712
721
if (stopped_word || stopped_eos) {
713
722
finish_reason = " stop" ;
@@ -776,7 +785,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
776
785
}
777
786
778
787
// return value is vector as there is one case where we might need to generate two responses
779
- static std::vector<json> format_partial_response_oaicompat (json result, const std::string & completion_id) {
788
+ static std::vector<json> format_partial_response_oaicompat (json request ,json result, const std::string & completion_id) {
780
789
if (!result.contains (" model" ) || !result.contains (" oaicompat_token_ctr" )) {
781
790
return std::vector<json>({result});
782
791
}
@@ -789,6 +798,66 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
789
798
bool stopped_limit = json_value (result, " stopped_limit" , false );
790
799
std::string content = json_value (result, " content" , std::string (" " ));
791
800
801
+ std::vector<json> parsed_content = parsePythonFunctionCalls (content, request[" tool_name_map" ]);
802
+ std::time_t t = std::time (0 );
803
+ if (!parsed_content.empty ()) {
804
+ std::vector<json> res;
805
+ json choices1 = json::array ({json{{" finish_reason" , nullptr },
806
+ {" index" , 0 },
807
+ {" delta" , json{{" role" , " assistant" }}}}});
808
+
809
+ json ret = json{
810
+ {" choices" , choices1},
811
+ {" created" , t},
812
+ {" id" , completion_id},
813
+ {" model" , modelname},
814
+ {" object" , " chat.completion.chunk" }
815
+ };
816
+ res.push_back (ret);
817
+
818
+ for (size_t i = 0 ; i < parsed_content.size (); ++i) {
819
+ const auto &pc = parsed_content[i];
820
+ // Use 'pc' and 'i' as needed
821
+ json tool_call1;
822
+ tool_call1[" id" ] = pc[" id" ];
823
+ tool_call1[" type" ] = " function" ;
824
+ tool_call1[" index" ] = i;
825
+ tool_call1[" function" ] = json{
826
+ {" name" , pc[" name" ]},
827
+ {" arguments" , " " },
828
+ };
829
+ json ret1 = json{
830
+ {" choices" , json::array ({json{{" finish_reason" , nullptr },
831
+ {" index" , 0 },
832
+ {" delta" , json{{" tool_calls" , std::vector<json>{tool_call1}}}}}})
833
+ },
834
+ {" created" , t},
835
+ {" id" , completion_id},
836
+ {" model" , modelname},
837
+ {" object" , " chat.completion.chunk" }
838
+ };
839
+ res.push_back (ret1);
840
+ json tool_call2;
841
+ tool_call2[" index" ] = i;
842
+ tool_call2[" function" ] = json{
843
+ {" name" , " " },
844
+ {" arguments" , pc[" kwargs" ].dump ()},
845
+ };
846
+ json ret2 = json{
847
+ {" choices" , json::array ({json{{" finish_reason" , nullptr },
848
+ {" index" , 0 },
849
+ {" delta" , json{{" tool_calls" , std::vector<json>{tool_call2}}}}}})
850
+ },
851
+ {" created" , t},
852
+ {" id" , completion_id},
853
+ {" model" , modelname},
854
+ {" object" , " chat.completion.chunk" }
855
+ };
856
+ res.push_back (ret2);
857
+ }
858
+ return res;
859
+ }
860
+
792
861
std::string finish_reason;
793
862
if (stopped_word || stopped_eos) {
794
863
finish_reason = " stop" ;
@@ -797,7 +866,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
797
866
finish_reason = " length" ;
798
867
}
799
868
800
- std:: time_t t = std::time ( 0 );
869
+
801
870
802
871
json choices;
803
872
0 commit comments