Skip to content

Commit b63a157

Browse files
committed
Don't cache agent tools during a run
### Summary: Towards #767. We were caching the list of tools for an agent, so if you did `agent.tools.append(...)` from a tool call, the next call to the model wouldn't include the new tool. THis is a bug. ### Test Plan: Unit tests. Note that now MCP tools are listed each time the agent runs (users can still cache the `list_tools` however).
1 parent 775d3e2 commit b63a157

File tree

4 files changed

+120
-33
lines changed

4 files changed

+120
-33
lines changed

src/agents/run.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from __future__ import annotations
32

43
import asyncio
@@ -182,6 +181,8 @@ async def run(
182181

183182
try:
184183
while True:
184+
all_tools = await cls._get_all_tools(current_agent)
185+
185186
# Start an agent span if we don't have one. This span is ended if the current
186187
# agent changes, or if the agent loop ends.
187188
if current_span is None:
@@ -197,8 +198,6 @@ async def run(
197198
output_type=output_type_name,
198199
)
199200
current_span.start(mark_as_current=True)
200-
201-
all_tools = await cls._get_all_tools(current_agent)
202201
current_span.span_data.tools = [t.name for t in all_tools]
203202

204203
current_turn += 1
@@ -210,9 +209,7 @@ async def run(
210209
data={"max_turns": max_turns},
211210
),
212211
)
213-
raise MaxTurnsExceeded(
214-
f"Max turns ({max_turns}) exceeded"
215-
)
212+
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
216213

217214
logger.debug(
218215
f"Running agent {current_agent.name} (turn {current_turn})",
@@ -295,7 +292,7 @@ async def run(
295292
last_agent=current_agent,
296293
context_wrapper=context_wrapper,
297294
input_guardrail_results=input_guardrail_results,
298-
output_guardrail_results=[]
295+
output_guardrail_results=[],
299296
)
300297
raise
301298
finally:
@@ -528,6 +525,8 @@ async def _run_streamed_impl(
528525
if streamed_result.is_complete:
529526
break
530527

528+
all_tools = await cls._get_all_tools(current_agent)
529+
531530
# Start an agent span if we don't have one. This span is ended if the current
532531
# agent changes, or if the agent loop ends.
533532
if current_span is None:
@@ -543,8 +542,6 @@ async def _run_streamed_impl(
543542
output_type=output_type_name,
544543
)
545544
current_span.start(mark_as_current=True)
546-
547-
all_tools = await cls._get_all_tools(current_agent)
548545
tool_names = [t.name for t in all_tools]
549546
current_span.span_data.tools = tool_names
550547
current_turn += 1

tests/mcp/test_mcp_tracing.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ async def test_mcp_tracing():
4444
{
4545
"workflow_name": "Agent workflow",
4646
"children": [
47+
{
48+
"type": "mcp_tools",
49+
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
50+
},
4751
{
4852
"type": "agent",
4953
"data": {
@@ -53,21 +57,21 @@ async def test_mcp_tracing():
5357
"output_type": "str",
5458
},
5559
"children": [
56-
{
57-
"type": "mcp_tools",
58-
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
59-
},
6060
{
6161
"type": "function",
6262
"data": {
6363
"name": "test_tool_1",
6464
"input": "",
65-
"output": '{"type":"text","text":"result_test_tool_1_{}","annotations":null}', # noqa: E501
65+
"output": '{"type":"text","text":"result_test_tool_1_{}","annotations":null}',
6666
"mcp_data": {"server": "fake_mcp_server"},
6767
},
6868
},
69+
{
70+
"type": "mcp_tools",
71+
"data": {"server": "fake_mcp_server", "result": ["test_tool_1"]},
72+
},
6973
],
70-
}
74+
},
7175
],
7276
}
7377
]
@@ -100,6 +104,13 @@ async def test_mcp_tracing():
100104
{
101105
"workflow_name": "Agent workflow",
102106
"children": [
107+
{
108+
"type": "mcp_tools",
109+
"data": {
110+
"server": "fake_mcp_server",
111+
"result": ["test_tool_1", "test_tool_2"],
112+
},
113+
},
103114
{
104115
"type": "agent",
105116
"data": {
@@ -109,13 +120,6 @@ async def test_mcp_tracing():
109120
"output_type": "str",
110121
},
111122
"children": [
112-
{
113-
"type": "mcp_tools",
114-
"data": {
115-
"server": "fake_mcp_server",
116-
"result": ["test_tool_1", "test_tool_2"],
117-
},
118-
},
119123
{
120124
"type": "function",
121125
"data": {
@@ -129,12 +133,19 @@ async def test_mcp_tracing():
129133
"data": {
130134
"name": "test_tool_2",
131135
"input": "",
132-
"output": '{"type":"text","text":"result_test_tool_2_{}","annotations":null}', # noqa: E501
136+
"output": '{"type":"text","text":"result_test_tool_2_{}","annotations":null}',
133137
"mcp_data": {"server": "fake_mcp_server"},
134138
},
135139
},
140+
{
141+
"type": "mcp_tools",
142+
"data": {
143+
"server": "fake_mcp_server",
144+
"result": ["test_tool_1", "test_tool_2"],
145+
},
146+
},
136147
],
137-
}
148+
},
138149
],
139150
}
140151
]
@@ -165,6 +176,13 @@ async def test_mcp_tracing():
165176
{
166177
"workflow_name": "Agent workflow",
167178
"children": [
179+
{
180+
"type": "mcp_tools",
181+
"data": {
182+
"server": "fake_mcp_server",
183+
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
184+
},
185+
},
168186
{
169187
"type": "agent",
170188
"data": {
@@ -174,24 +192,24 @@ async def test_mcp_tracing():
174192
"output_type": "str",
175193
},
176194
"children": [
177-
{
178-
"type": "mcp_tools",
179-
"data": {
180-
"server": "fake_mcp_server",
181-
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
182-
},
183-
},
184195
{
185196
"type": "function",
186197
"data": {
187198
"name": "test_tool_3",
188199
"input": "",
189-
"output": '{"type":"text","text":"result_test_tool_3_{}","annotations":null}', # noqa: E501
200+
"output": '{"type":"text","text":"result_test_tool_3_{}","annotations":null}',
190201
"mcp_data": {"server": "fake_mcp_server"},
191202
},
192203
},
204+
{
205+
"type": "mcp_tools",
206+
"data": {
207+
"server": "fake_mcp_server",
208+
"result": ["test_tool_1", "test_tool_2", "test_tool_3"],
209+
},
210+
},
193211
],
194-
}
212+
},
195213
],
196214
}
197215
]

tests/test_agent_runner.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn():
745745
pass
746746

747747
assert model.last_turn_args.get("previous_response_id") == "resp-stream-test"
748+
749+
750+
@pytest.mark.asyncio
751+
async def test_dynamic_tool_addition_run() -> None:
752+
"""Test that tools can be added to an agent during a run."""
753+
model = FakeModel()
754+
755+
executed: dict[str, bool] = {"called": False}
756+
757+
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
758+
759+
@function_tool(name_override="tool2")
760+
def tool2() -> str:
761+
executed["called"] = True
762+
return "result2"
763+
764+
@function_tool(name_override="add_tool")
765+
async def add_tool() -> str:
766+
agent.tools.append(tool2)
767+
return "added"
768+
769+
agent.tools.append(add_tool)
770+
771+
model.add_multiple_turn_outputs(
772+
[
773+
[get_function_tool_call("add_tool", json.dumps({}))],
774+
[get_function_tool_call("tool2", json.dumps({}))],
775+
[get_text_message("done")],
776+
]
777+
)
778+
779+
result = await Runner.run(agent, input="start")
780+
781+
assert executed["called"] is True
782+
assert result.final_output == "done"

tests/test_agent_runner_streamed.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
RunContextWrapper,
1919
Runner,
2020
UserError,
21+
function_tool,
2122
handoff,
2223
)
2324
from agents.items import RunItem
@@ -684,3 +685,39 @@ async def test_streaming_events():
684685
assert len(agent_data) == 2, "should have 2 agent updated events"
685686
assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
686687
assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
688+
689+
690+
@pytest.mark.asyncio
691+
async def test_dynamic_tool_addition_run_streamed() -> None:
692+
model = FakeModel()
693+
694+
executed: dict[str, bool] = {"called": False}
695+
696+
agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
697+
698+
@function_tool(name_override="tool2")
699+
def tool2() -> str:
700+
executed["called"] = True
701+
return "result2"
702+
703+
@function_tool(name_override="add_tool")
704+
async def add_tool() -> str:
705+
agent.tools.append(tool2)
706+
return "added"
707+
708+
agent.tools.append(add_tool)
709+
710+
model.add_multiple_turn_outputs(
711+
[
712+
[get_function_tool_call("add_tool", json.dumps({}))],
713+
[get_function_tool_call("tool2", json.dumps({}))],
714+
[get_text_message("done")],
715+
]
716+
)
717+
718+
result = Runner.run_streamed(agent, input="start")
719+
async for _ in result.stream_events():
720+
pass
721+
722+
assert executed["called"] is True
723+
assert result.final_output == "done"

0 commit comments

Comments
 (0)