Skip to content

Commit 682f49c

Browse files
committed
Don't cache agent tools during a run
### Summary: Towards #767. We were caching the list of tools for an agent, so if you did `agent.tools.append(...)` from a tool call, the next call to the model wouldn't include the new tool. This is a bug. ### Test Plan: Unit tests. Note that now MCP tools are listed each time the agent runs (users can still cache the `list_tools` however).
1 parent 775d3e2 commit 682f49c

File tree

11 files changed

+143
-48
lines changed

11 files changed

+143
-48
lines changed

src/agents/exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
@dataclass
1616
class RunErrorDetails:
1717
"""Data collected from an agent run when an exception occurs."""
18+
1819
input: str | list[TResponseInputItem]
1920
new_items: list[RunItem]
2021
raw_responses: list[ModelResponse]
@@ -29,6 +30,7 @@ def __str__(self) -> str:
2930

3031
class AgentsException(Exception):
3132
"""Base class for all exceptions in the Agents SDK."""
33+
3234
run_data: RunErrorDetails | None
3335

3436
def __init__(self, *args: object) -> None:

src/agents/extensions/models/litellm_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ async def get_response(
110110
input_tokens_details=InputTokensDetails(
111111
cached_tokens=getattr(
112112
response_usage.prompt_tokens_details, "cached_tokens", 0
113-
) or 0
113+
)
114+
or 0
114115
),
115116
output_tokens_details=OutputTokensDetails(
116117
reasoning_tokens=getattr(
117118
response_usage.completion_tokens_details, "reasoning_tokens", 0
118-
) or 0
119+
)
120+
or 0
119121
),
120122
)
121123
if response.usage

src/agents/mcp/server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def create_streams(
8888
tuple[
8989
MemoryObjectReceiveStream[SessionMessage | Exception],
9090
MemoryObjectSendStream[SessionMessage],
91-
GetSessionIdCallback | None
91+
GetSessionIdCallback | None,
9292
]
9393
]:
9494
"""Create the streams for the server."""
@@ -243,7 +243,7 @@ def create_streams(
243243
tuple[
244244
MemoryObjectReceiveStream[SessionMessage | Exception],
245245
MemoryObjectSendStream[SessionMessage],
246-
GetSessionIdCallback | None
246+
GetSessionIdCallback | None,
247247
]
248248
]:
249249
"""Create the streams for the server."""
@@ -314,7 +314,7 @@ def create_streams(
314314
tuple[
315315
MemoryObjectReceiveStream[SessionMessage | Exception],
316316
MemoryObjectSendStream[SessionMessage],
317-
GetSessionIdCallback | None
317+
GetSessionIdCallback | None,
318318
]
319319
]:
320320
"""Create the streams for the server."""
@@ -394,7 +394,7 @@ def create_streams(
394394
tuple[
395395
MemoryObjectReceiveStream[SessionMessage | Exception],
396396
MemoryObjectSendStream[SessionMessage],
397-
GetSessionIdCallback | None
397+
GetSessionIdCallback | None,
398398
]
399399
]:
400400
"""Create the streams for the server."""
@@ -403,7 +403,7 @@ def create_streams(
403403
headers=self.params.get("headers", None),
404404
timeout=self.params.get("timeout", timedelta(seconds=30)),
405405
sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)),
406-
terminate_on_close=self.params.get("terminate_on_close", True)
406+
terminate_on_close=self.params.get("terminate_on_close", True),
407407
)
408408

409409
@property

src/agents/result.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,3 @@ def _cleanup_tasks(self):
274274

275275
def __str__(self) -> str:
276276
return pretty_print_run_result_streaming(self)
277-

src/agents/run.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from __future__ import annotations
32

43
import asyncio
@@ -182,6 +181,8 @@ async def run(
182181

183182
try:
184183
while True:
184+
all_tools = await cls._get_all_tools(current_agent)
185+
185186
# Start an agent span if we don't have one. This span is ended if the current
186187
# agent changes, or if the agent loop ends.
187188
if current_span is None:
@@ -197,8 +198,6 @@ async def run(
197198
output_type=output_type_name,
198199
)
199200
current_span.start(mark_as_current=True)
200-
201-
all_tools = await cls._get_all_tools(current_agent)
202201
current_span.span_data.tools = [t.name for t in all_tools]
203202

204203
current_turn += 1
@@ -210,9 +209,7 @@ async def run(
210209
data={"max_turns": max_turns},
211210
),
212211
)
213-
raise MaxTurnsExceeded(
214-
f"Max turns ({max_turns}) exceeded"
215-
)
212+
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
216213

217214
logger.debug(
218215
f"Running agent {current_agent.name} (turn {current_turn})",
@@ -295,7 +292,7 @@ async def run(
295292
last_agent=current_agent,
296293
context_wrapper=context_wrapper,
297294
input_guardrail_results=input_guardrail_results,
298-
output_guardrail_results=[]
295+
output_guardrail_results=[],
299296
)
300297
raise
301298
finally:
@@ -528,6 +525,8 @@ async def _run_streamed_impl(
528525
if streamed_result.is_complete:
529526
break
530527

528+
all_tools = await cls._get_all_tools(current_agent)
529+
531530
# Start an agent span if we don't have one. This span is ended if the current
532531
# agent changes, or if the agent loop ends.
533532
if current_span is None:
@@ -543,8 +542,6 @@ async def _run_streamed_impl(
543542
output_type=output_type_name,
544543
)
545544
current_span.start(mark_as_current=True)
546-
547-
all_tools = await cls._get_all_tools(current_agent)
548545
tool_names = [t.name for t in all_tools]
549546
current_span.span_data.tools = tool_names
550547
current_turn += 1

src/agents/voice/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
1818
"""Exportable type for the TTSModelSettings voice enum"""
1919

20+
2021
@dataclass
2122
class TTSModelSettings:
2223
"""Settings for a TTS model."""
24+
2325
voice: TTSVoice | None = None
2426
"""
2527
The voice to use for the TTS model. If not provided, the default voice for the respective model

tests/mcp/test_mcp_tracing.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ async def test_mcp_tracing():
4444
{
4545
"workflow_name": "Agent workflow",
4646
"children": [
47+
{
48+
"type": "mcp_tools",
49+
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
50+
},
4751
{
4852
"type": "agent",
4953
"data": {
@@ -53,10 +57,6 @@ async def test_mcp_tracing():
5357
"output_type": "str",
5458
},
5559
"children": [
56-
{
57-
"type": "mcp_tools",
58-
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
59-
},
6060
{
6161
"type": "function",
6262
"data": {
@@ -66,8 +66,12 @@ async def test_mcp_tracing():
6666
"mcp_data": {"server": "fake_mcp_server"},
6767
},
6868
},
69+
{
70+
"type": "mcp_tools",
71+
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
72+
},
6973
],
70-
}
74+
},
7175
],
7276
}
7377
]
@@ -100,6 +104,13 @@ async def test_mcp_tracing():
100104
{
101105
"workflow_name": "Agent workflow",
102106
"children": [
107+
{
108+
"type": "mcp_tools",
109+
"data": {
110+
"server": "fake_mcp_server",
111+
"result": ["test_tool_1", "test_tool_2"],
112+
},
113+
},
103114
{
104115
"type": "agent",
105116
"data": {
@@ -109,13 +120,6 @@ async def test_mcp_tracing():
109120
"output_type": "str",
110121
},
111122
"children": [
112-
{
113-
"type": "mcp_tools",
114-
"data": {
115-
"server": "fake_mcp_server",
116-
"result": ["test_tool_1", "test_tool_2"],
117-
},
118-
},
119123
{
120124
"type": "function",
121125
"data": {
@@ -133,8 +137,15 @@ async def test_mcp_tracing():
133137
"mcp_data": {"server": "fake_mcp_server"},
134138
},
135139
},
140+
{
141+
"type": "mcp_tools",
142+
"data": {
143+
"server": "fake_mcp_server",
144+
"result": ["test_tool_1", "test_tool_2"],
145+
},
146+
},
136147
],
137-
}
148+
},
138149
],
139150
}
140151
]
@@ -165,6 +176,13 @@ async def test_mcp_tracing():
165176
{
166177
"workflow_name": "Agent workflow",
167178
"children": [
179+
{
180+
"type": "mcp_tools",
181+
"data": {
182+
"server": "fake_mcp_server",
183+
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
184+
},
185+
},
168186
{
169187
"type": "agent",
170188
"data": {
@@ -174,13 +192,6 @@ async def test_mcp_tracing():
174192
"output_type": "str",
175193
},
176194
"children": [
177-
{
178-
"type": "mcp_tools",
179-
"data": {
180-
"server": "fake_mcp_server",
181-
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
182-
},
183-
},
184195
{
185196
"type": "function",
186197
"data": {
@@ -190,8 +201,15 @@ async def test_mcp_tracing():
190201
"mcp_data": {"server": "fake_mcp_server"},
191202
},
192203
},
204+
{
205+
"type": "mcp_tools",
206+
"data": {
207+
"server": "fake_mcp_server",
208+
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
209+
},
210+
},
193211
],
194-
}
212+
},
195213
],
196214
}
197215
]

tests/models/test_litellm_extra_body.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
2626

2727
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
2828
settings = ModelSettings(
29-
temperature=0.1,
30-
extra_body={"cached_content": "some_cache", "foo": 123}
29+
temperature=0.1, extra_body={"cached_content": "some_cache", "foo": 123}
3130
)
3231
model = LitellmModel(model="test-model")
3332

tests/test_agent_runner.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn():
745745
pass
746746

747747
assert model.last_turn_args.get("previous_response_id") == "resp-stream-test"
748+
749+
750+
@pytest.mark.asyncio
751+
async def test_dynamic_tool_addition_run() -> None:
752+
"""Test that tools can be added to an agent during a run."""
753+
model = FakeModel()
754+
755+
executed: dict[str, bool] = {"called": False}
756+
757+
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
758+
759+
@function_tool(name_override="tool2")
760+
def tool2() -> str:
761+
executed["called"] = True
762+
return "result2"
763+
764+
@function_tool(name_override="add_tool")
765+
async def add_tool() -> str:
766+
agent.tools.append(tool2)
767+
return "added"
768+
769+
agent.tools.append(add_tool)
770+
771+
model.add_multiple_turn_outputs(
772+
[
773+
[get_function_tool_call("add_tool", json.dumps({}))],
774+
[get_function_tool_call("tool2", json.dumps({}))],
775+
[get_text_message("done")],
776+
]
777+
)
778+
779+
result = await Runner.run(agent, input="start")
780+
781+
assert executed["called"] is True
782+
assert result.final_output == "done"

tests/test_agent_runner_streamed.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
RunContextWrapper,
1919
Runner,
2020
UserError,
21+
function_tool,
2122
handoff,
2223
)
2324
from agents.items import RunItem
@@ -684,3 +685,39 @@ async def test_streaming_events():
684685
assert len(agent_data) == 2, "should have 2 agent updated events"
685686
assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
686687
assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
688+
689+
690+
@pytest.mark.asyncio
691+
async def test_dynamic_tool_addition_run_streamed() -> None:
692+
model = FakeModel()
693+
694+
executed: dict[str, bool] = {"called": False}
695+
696+
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
697+
698+
@function_tool(name_override="tool2")
699+
def tool2() -> str:
700+
executed["called"] = True
701+
return "result2"
702+
703+
@function_tool(name_override="add_tool")
704+
async def add_tool() -> str:
705+
agent.tools.append(tool2)
706+
return "added"
707+
708+
agent.tools.append(add_tool)
709+
710+
model.add_multiple_turn_outputs(
711+
[
712+
[get_function_tool_call("add_tool", json.dumps({}))],
713+
[get_function_tool_call("tool2", json.dumps({}))],
714+
[get_text_message("done")],
715+
]
716+
)
717+
718+
result = Runner.run_streamed(agent, input="start")
719+
async for _ in result.stream_events():
720+
pass
721+
722+
assert executed["called"] is True
723+
assert result.final_output == "done"

0 commit comments

Comments
 (0)