Skip to content

Add is_enabled to FunctionTool #808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import dataclasses
import inspect
from collections.abc import Awaitable
Expand All @@ -17,7 +18,7 @@
from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
from .tool import FunctionToolResult, Tool, function_tool
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable

Expand Down Expand Up @@ -246,7 +247,22 @@ async def get_mcp_tools(self) -> list[Tool]:
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)

async def get_all_tools(self) -> list[Tool]:
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
"""All agent tools, including MCP tools and function tools."""
mcp_tools = await self.get_mcp_tools()
return mcp_tools + self.tools

async def _check_tool_enabled(tool: Tool) -> bool:
if not isinstance(tool, FunctionTool):
return True

attr = tool.is_enabled
if isinstance(attr, bool):
return attr
res = attr(run_context, self)
if inspect.isawaitable(res):
return bool(await res)
return bool(res)

results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
return [*mcp_tools, *enabled]
10 changes: 6 additions & 4 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async def run(

try:
while True:
all_tools = await cls._get_all_tools(current_agent)
all_tools = await cls._get_all_tools(current_agent, context_wrapper)

# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
Expand Down Expand Up @@ -525,7 +525,7 @@ async def _run_streamed_impl(
if streamed_result.is_complete:
break

all_tools = await cls._get_all_tools(current_agent)
all_tools = await cls._get_all_tools(current_agent, context_wrapper)

# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
Expand Down Expand Up @@ -980,8 +980,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
return handoffs

@classmethod
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
return await agent.get_all_tools()
async def _get_all_tools(
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
) -> list[Tool]:
return await agent.get_all_tools(context_wrapper)

@classmethod
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
Expand Down
17 changes: 16 additions & 1 deletion src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload

from openai.types.responses.file_search_tool_param import Filters, RankingOptions
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
Expand All @@ -24,6 +24,9 @@
from .util import _error_tracing
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
from .agent import Agent

ToolParams = ParamSpec("ToolParams")

ToolFunctionWithoutContext = Callable[ToolParams, Any]
Expand Down Expand Up @@ -74,6 +77,11 @@ class FunctionTool:
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""

is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""


@dataclass
class FileSearchTool:
Expand Down Expand Up @@ -262,6 +270,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
Expand All @@ -276,6 +285,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
Expand All @@ -290,6 +300,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
Expand Down Expand Up @@ -318,6 +329,9 @@ def function_tool(
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
value, it will be optional, additional properties are allowed, etc. See here for more:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the tool is enabled. Disabled tools are hidden
from the LLM at runtime.
"""

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
Expand Down Expand Up @@ -407,6 +421,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
is_enabled=is_enabled,
)

# If func is actually a callable, we were used as @function_tool with no parentheses
Expand Down
43 changes: 42 additions & 1 deletion tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel
from typing_extensions import TypedDict

from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents.tool import default_tool_error_function


Expand Down Expand Up @@ -255,3 +255,44 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->

result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
assert result == "error_ValueError"


class BoolCtx(BaseModel):
enable_tools: bool


@pytest.mark.asyncio
async def test_is_enabled_bool_and_callable():
@function_tool(is_enabled=False)
def disabled_tool():
return "nope"

async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
return ctx.context.enable_tools

@function_tool(is_enabled=cond_enabled)
def another_tool():
return "hi"

async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
return "third"

third_tool = FunctionTool(
name="third_tool",
description="third tool",
on_invoke_tool=third_tool_on_invoke_tool,
is_enabled=lambda ctx, agent: ctx.context.enable_tools,
params_json_schema={},
)

agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
context_2 = RunContextWrapper(BoolCtx(enable_tools=True))

tools_with_ctx = await agent.get_all_tools(context_1)
assert tools_with_ctx == []

tools_with_ctx = await agent.get_all_tools(context_2)
assert len(tools_with_ctx) == 2
assert tools_with_ctx[0].name == "another_tool"
assert tools_with_ctx[1].name == "third_tool"
2 changes: 1 addition & 1 deletion tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ async def get_execute_result(

processed_response = RunImpl.process_model_response(
agent=agent,
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
response=response,
output_schema=output_schema,
handoffs=handoffs,
Expand Down
32 changes: 18 additions & 14 deletions tests/test_run_step_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
)


def _dummy_ctx() -> RunContextWrapper[None]:
return RunContextWrapper(context=None)


def test_empty_response():
agent = Agent(name="test")
response = ModelResponse(
Expand Down Expand Up @@ -83,7 +87,7 @@ async def test_single_tool_call():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert not result.handoffs
assert result.functions and len(result.functions) == 1
Expand Down Expand Up @@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)


Expand Down Expand Up @@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert not result.handoffs
assert result.functions and len(result.functions) == 2
Expand Down Expand Up @@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert not result.handoffs, "Shouldn't have a handoff here"

Expand All @@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert len(result.handoffs) == 1, "Should have a handoff here"
handoff = result.handoffs[0]
Expand Down Expand Up @@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)


Expand All @@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert len(result.handoffs) == 2, "Should have multiple handoffs here"

Expand All @@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
response=response,
output_schema=Runner._get_output_schema(agent),
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)


Expand All @@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
# The final item should be a ToolCallItem for the file search call
assert any(
Expand All @@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert any(
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
Expand All @@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await Agent(name="test").get_all_tools(),
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
)
assert any(
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
Expand Down Expand Up @@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
response=response,
output_schema=None,
handoffs=[],
all_tools=await Agent(name="test").get_all_tools(),
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
)


Expand Down Expand Up @@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert any(
isinstance(item, ToolCallItem) and item.raw_item is computer_call
Expand Down Expand Up @@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert result.functions and len(result.functions) == 1
assert len(result.handoffs) == 1, "Should have a handoff here"
Expand Down