Support tools #1587


Merged: 19 commits from support-tools into main on Feb 28, 2024

Conversation

@drbh (Collaborator) commented Feb 22, 2024

This work-in-progress PR begins to add support for tools. Tool support relies on grammar support and still has some unsolved challenges. Opening the PR for visibility and feedback.

@drbh drbh marked this pull request as ready for review February 26, 2024 14:30
@drbh (Collaborator, Author) commented Feb 26, 2024

This PR adds basic support for tools/functions. The API follows the same format as OpenAI's and can be used with the OpenAI Python client library.

Note the current differences between TGI functions and OpenAI's:

  • TGI's tool_choice="auto" will ALWAYS choose a function; OpenAI's can decide to choose none or to ask a follow-up question.

Example use (with TGI running locally):

from openai import OpenAI

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="_",
)


tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    }
]

Use one of the provided tools

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)


called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#     "id": 0,
#     "type": "function",
#     "function": {
#         "description": None,
#         "name": "tools",
#         "parameters": {
#             "format": "celsius",
#             "location": "San Francisco, CA",
#             "num_days": 3,
#         },
#     },
# }

Use a specific tool

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="get_current_weather",  # tool selected by caller
    max_tokens=500,
)

called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#     "id": 0,
#     "type": "function",
#     "function": {
#         "description": None,
#         "name": "tools",
#         "parameters": {"format": "celsius", "location": "San Francisco, CA"},
#     },
# }

@doc mentioned this pull request Feb 26, 2024
@Narsil (Collaborator) commented Feb 27, 2024

Looking good overall.

We definitely need some docs!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Narsil Narsil merged commit 9b6db5f into main Feb 28, 2024
@Narsil Narsil deleted the support-tools branch February 28, 2024 10:10
@RonanKMcGovern commented Mar 5, 2024

Thanks @drbh! A few thoughts/points:

  1. Why does 'auto' always result in a function call? Is it because of using outlines or similar for controlling generation (I don't obviously see that in the code, at least not outlines)? Or some other reason? I'm trying to grasp what it will take to make it truly auto (apart from the model being well trained).

  2. I note you have model="tgi" in the example; I suppose that would be updated to the actual model being served by TGI? Perhaps something to clarify in the docs. I tried it both ways and it seemed to work, so the parameter may be redundant?

  3. An OpenAI endpoint returns a list, whereas TGI currently returns (at least from the OpenAI-style endpoint) a dictionary. This makes parsing different (see the defensive sketch after this list):

ChatCompletion(id='', choices=[Choice(finish_reason='eos_token', index=0, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls={'id': 0, 'type': 'function', 'function': {'description': None, 'name': 'tools', 'parameters': {'location': 'London'}}}), logprobs=None)], created=1709655533, model='Trelis/openchat_3.5-function-calling-v3', object='text_completion', system_fingerprint='1.4.3-sha-7dbaf9e', usage=CompletionUsage(completion_tokens=10, prompt_tokens=164, total_tokens=174))

with this:

ChatCompletion(id='chatcmpl-8zRomovpipeQ2EfecKtoVE8yXarpy', choices=[Choice(finish_reason='tool_calls', index=0, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_RSVG6xXRXEXusk5zJEYaONfv', function=Function(arguments='{\n  "location": "London"\n}', name='get_current_weather'), type='function')]), logprobs=None)], created=1709655632, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=226, total_tokens=242))
  4. The chat template is still loaded from the base model, correct? But the function-calling syntax is injected on top of that?

  5. To the extent it can be made very clear in the docs exactly how the functions/tools are being formatted - specifically, what helper text is being added - that will help model developers ensure support. Typically (and it makes sense for closed APIs), it's very hard to track exactly how tools are being formatted. Right now I see there is a flag for a custom prompt before tools, but it's not entirely clear what the defaults are.

@puppetm4st3r commented Mar 6, 2024

Hi, nice job! I found an inconsistency between the output of the selected tool and the OpenAI API specification, and created issue #1624 for it. I think I solved it, but it's my first time with Rust and I couldn't get the code to work in my local virtual env (the instructions from the readme.md didn't work), so I modified the code and ran it through the Docker build; the container then behaved according to the OpenAI specification. Could you guide me, @drbh, on how to proceed? I need the regular pipeline to run it locally and the tests so I can open the PR the appropriate way.

The PR only modifies server.rs and lib.rs. The improvement is in the grammar generated for the tool-selection inference: the PR's grammar complies with the JSON Schema standard and allows the LLM to indicate the name of the selected function, and the struct of the grammar fits the OpenAI schema.

Now the output for:

from openai import OpenAI
# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://llm_server:3000/v1",
    api_key="_"
)

# NOTE: tools defined above and removed for brevity

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)


called = chat_completion.choices[0].message.tool_calls
print(called)

code output:
[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"format":"fahrenheit","location":"San Francisco, CA","num_days":3}', name='get_n_day_weather_forecast'), type='function')]

Output per the OpenAI API docs:
[ChatCompletionMessageToolCall(id='call_ujD1NwPxzeOSCbgw2NOabOin', function=Function(arguments='{\n "location": "Glasgow, Scotland",\n "format": "celsius",\n "num_days": 5\n}', name='get_n_day_weather_forecast'), type='function')]

Raw LLM output from TGI debug tracing:

{'id': 0, 'type': 'function', 'function': {'name': 'get_n_day_weather_forecast', 'arguments': {'format': 'celsius', 'location': 'San Francisco, CA', 'num_days': 3}}}

I did not touch the streaming methods or the ChatCompletionChunk and ChatCompletionDelta objects; my Rust understanding is pretty basic.

@jphme commented Mar 14, 2024

@drbh @Narsil Is there any documentation (or examples) on

  1. how the function definitions are serialized (as strings) for the LLM prompt, and
  2. what the function calls parsed from the LLM look like (is it the format that @puppetm4st3r posted in the last post as "llm raw output")?

This would be very helpful for everyone training models to increase the reliability of tool calls and structured output. Unfortunately there is no single standard across the large inference libraries, and almost all models use different formats (requiring custom wrappers) as well. It would be a huge win for the whole OSS LLM community if you added detailed documentation of the formats, which could hopefully then be adopted by other inference libraries and guarantee OpenAI compatibility.

Many thanks!

@RonanKMcGovern

+1, I couldn't have put this better myself.

@puppetm4st3r commented Mar 14, 2024

I have solved it, but these days I have not had time to investigate how to create the tests and the pull request; I am too overloaded with my work. If someone can support us, I can gladly hand over the changes made in the few files that needed modifying so that some member can generate the PR (I've never made a contribution with Git, it's my first time, and unfortunately I haven't had the space for the learning curve).

kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024