Skip to content

Add exclude_params feature to function_tool for hidden parameter injection #795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ for tool in agent.tools:

1. You can use any Python types as arguments to your functions, and the function can be sync or async.
2. Docstrings, if present, are used to capture descriptions and argument descriptions
3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, etc.
3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, and exclude specific parameters from the schema, etc.
4. You can pass the decorated functions to the list of tools.

??? note "Expand to see output"
Expand Down Expand Up @@ -284,6 +284,45 @@ async def run_my_agent() -> str:
return str(result.final_output)
```

## Excluding parameters from the schema

Sometimes, you might want to exclude certain parameters from the JSON schema that's presented to the LLM, while still making them available to your function with their default values. This can be useful for:

- Keeping implementation details hidden from the LLM
- Simplifying the tool interface presented to the model
- Maintaining backward compatibility when adding new parameters
- Supporting internal parameters that should always use default values

You can do this using the `exclude_params` parameter of the `@function_tool` decorator:

```python
from typing import Optional
from agents import function_tool, RunContextWrapper

@function_tool(exclude_params=["timestamp", "internal_id"])
def search_database(
query: str,
limit: int = 10,
timestamp: Optional[str] = None,
internal_id: Optional[str] = None
) -> str:
"""
Search the database for records matching the query.

Args:
query: The search query string
limit: Maximum number of results to return
timestamp: The timestamp to use for the search (hidden from schema)
internal_id: Internal tracking ID for telemetry (hidden from schema)
"""
# Implementation...
```

In this example:
- The LLM will only see `query` and `limit` parameters in the tool schema
- `timestamp` and `internal_id` will be automatically set to their default values when the function runs
- All excluded parameters must have default values (either `None` or a specific value)

## Handling errors in function tools

When you create a function tool via `@function_tool`, you can pass a `failure_error_function`. This is a function that provides an error response to the LLM in case the tool call crashes.
Expand Down
43 changes: 41 additions & 2 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,22 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
positional_args: list[Any] = []
keyword_args: dict[str, Any] = {}
seen_var_positional = False

# Get excluded parameter defaults if they exist
excluded_param_defaults = getattr(self.params_pydantic_model, "__excluded_param_defaults__", {})

# Use enumerate() so we can skip the first parameter if it's context.
for idx, (name, param) in enumerate(self.signature.parameters.items()):
# If the function takes a RunContextWrapper and this is the first parameter, skip it.
if self.takes_context and idx == 0:
continue

value = getattr(data, name, None)
# For excluded parameters, use their default value
if name in excluded_param_defaults:
value = excluded_param_defaults[name]
else:
value = getattr(data, name, None)

if param.kind == param.VAR_POSITIONAL:
# e.g. *args: extend positional args and mark that *args is now seen
positional_args.extend(value or [])
Expand Down Expand Up @@ -190,6 +198,7 @@ def function_schema(
description_override: str | None = None,
use_docstring_info: bool = True,
strict_json_schema: bool = True,
exclude_params: list[str] | None = None,
) -> FuncSchema:
"""
Given a python function, extracts a `FuncSchema` from it, capturing the name, description,
Expand All @@ -208,6 +217,9 @@ def function_schema(
the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
recommend setting this to True, as it increases the likelihood of the LLM providing
correct JSON input.
exclude_params: If provided, these parameters will be excluded from the JSON schema
presented to the LLM. The parameters will still be available to the function with
their default values. All excluded parameters must have default values.

Returns:
A `FuncSchema` object containing the function's name, description, parameter descriptions,
Expand All @@ -231,11 +243,24 @@ def function_schema(
takes_context = False
filtered_params = []

# Store default values for excluded parameters
excluded_param_defaults = {}

if params:
first_name, first_param = params[0]
# Prefer the evaluated type hint if available
ann = type_hints.get(first_name, first_param.annotation)
if ann != inspect._empty:

# Check if this parameter should be excluded
if exclude_params and first_name in exclude_params:
# Ensure the parameter has a default value
if first_param.default is inspect._empty:
raise UserError(
f"Parameter '{first_name}' specified in exclude_params must have a default value"
)
# Store default value
excluded_param_defaults[first_name] = first_param.default
elif ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
takes_context = True # Mark that the function takes context
Expand All @@ -246,6 +271,17 @@ def function_schema(

# For parameters other than the first, raise error if any use RunContextWrapper.
for name, param in params[1:]:
# Check if this parameter should be excluded
if exclude_params and name in exclude_params:
# Ensure the parameter has a default value
if param.default is inspect._empty:
raise UserError(
f"Parameter '{name}' specified in exclude_params must have a default value"
)
# Store default value
excluded_param_defaults[name] = param.default
continue

ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
Expand Down Expand Up @@ -326,6 +362,9 @@ def function_schema(

# 3. Dynamically build a Pydantic model
dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)

# Store excluded parameter defaults in the model for later use
setattr(dynamic_model, "__excluded_param_defaults__", excluded_param_defaults)

# 4. Build JSON schema from that model
json_schema = dynamic_model.model_json_schema()
Expand Down
21 changes: 21 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
exclude_params: list[str] | None = None,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
Expand All @@ -276,6 +277,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
exclude_params: list[str] | None = None,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
Expand All @@ -290,6 +292,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
exclude_params: list[str] | None = None,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
Expand Down Expand Up @@ -318,16 +321,34 @@ def function_tool(
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
value, it will be optional, additional properties are allowed, etc. See here for more:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
exclude_params: If provided, these parameters will be excluded from the JSON schema
presented to the LLM. The parameters will still be available to the function with
their default values. All excluded parameters must have default values.
"""

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
# Check that all excluded parameters have default values
if exclude_params:
sig = inspect.signature(the_func)
for param_name in exclude_params:
if param_name not in sig.parameters:
raise UserError(
f"Parameter '{param_name}' specified in exclude_params doesn't exist in function {the_func.__name__}"
)
param = sig.parameters[param_name]
if param.default is inspect._empty:
raise UserError(
f"Parameter '{param_name}' specified in exclude_params must have a default value"
)

schema = function_schema(
func=the_func,
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
exclude_params=exclude_params,
)

async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,67 @@ def func_with_mapping(test_one: Mapping[str, int]) -> str:

with pytest.raises(UserError):
function_schema(func_with_mapping)


def function_with_optional_params(a: int, b: int = 5, c: str = "default"):
"""Function with multiple optional parameters."""
return f"{a}-{b}-{c}"


def test_exclude_params_feature():
"""Test the exclude_params feature works correctly."""
# Test excluding a single optional parameter
func_schema = function_schema(
function_with_optional_params,
exclude_params=["c"],
)

# Verify 'c' is not in the schema properties
assert "c" not in func_schema.params_json_schema.get("properties", {})

# Verify the excluded parameter defaults are stored
excluded_defaults = getattr(func_schema.params_pydantic_model, "__excluded_param_defaults__", {})
assert "c" in excluded_defaults
assert excluded_defaults["c"] == "default"

# Test function still works correctly with excluded parameter
valid_input = {"a": 10, "b": 20}
parsed = func_schema.params_pydantic_model(**valid_input)
args, kwargs_dict = func_schema.to_call_args(parsed)
result = function_with_optional_params(*args, **kwargs_dict)
assert result == "10-20-default" # 'c' should use its default value

# Test excluding multiple parameters
func_schema_multi = function_schema(
function_with_optional_params,
exclude_params=["b", "c"],
)

# Verify both 'b' and 'c' are not in the schema properties
assert "b" not in func_schema_multi.params_json_schema.get("properties", {})
assert "c" not in func_schema_multi.params_json_schema.get("properties", {})

# Test function still works correctly with multiple excluded parameters
valid_input = {"a": 10}
parsed = func_schema_multi.params_pydantic_model(**valid_input)
args, kwargs_dict = func_schema_multi.to_call_args(parsed)
result = function_with_optional_params(*args, **kwargs_dict)
assert result == "10-5-default" # 'b' and 'c' should use their default values


def function_with_required_param(a: int, b: str):
"""Function with required parameters only."""
return f"{a}-{b}"


def test_exclude_params_requires_default_value():
"""Test that excluding a parameter without a default value raises an error."""
# Attempt to exclude a parameter without a default value
with pytest.raises(UserError) as excinfo:
function_schema(
function_with_required_param,
exclude_params=["b"],
)

# Check the error message
assert "must have a default value" in str(excinfo.value)
36 changes: 36 additions & 0 deletions tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,39 @@ async def test_extract_descriptions_from_docstring():
"additionalProperties": False,
}
)


@function_tool(exclude_params=["timestamp"])
def function_with_excluded_param(
city: str, country: str = "US", timestamp: Optional[str] = None
) -> str:
"""Get the weather for a given city with timestamp.

Args:
city: The city to get the weather for.
country: The country the city is in.
timestamp: The timestamp for the weather data (hidden from schema).
"""
time_str = f" at {timestamp}" if timestamp else ""
return f"The weather in {city}, {country}{time_str} is sunny."


@pytest.mark.asyncio
async def test_exclude_params_from_schema():
"""Test that excluded parameters are not included in the schema."""
tool = function_with_excluded_param

# Check that the parameter is not in the schema
assert "timestamp" not in tool.params_json_schema.get("properties", {})

# Check that only non-excluded parameters are required
assert set(tool.params_json_schema.get("required", [])) == {"city"}

# Test function still works with excluded parameter
input_data = {"city": "Seattle", "country": "US"}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "The weather in Seattle, US is sunny."

# Test function works when we supply a default excluded parameter value in the code
function_result = function_with_excluded_param("Seattle", "US", "2023-05-29T12:00:00Z")
assert function_result == "The weather in Seattle, US at 2023-05-29T12:00:00Z is sunny."