Skip to content

Commit 31dd691

Browse files
committed
test: add unit tests for keep_last_n_items function
1 parent aba7319 commit 31dd691

File tree

1 file changed

+117
-1
lines changed

1 file changed

+117
-1
lines changed

tests/test_extension_filters.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import pytest
12
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
23

34
from agents import Agent, HandoffInputData
4-
from agents.extensions.handoff_filters import remove_all_tools
5+
from agents.extensions.handoff_filters import remove_all_tools, keep_last_n_items
56
from agents.items import (
67
HandoffOutputItem,
78
MessageOutputItem,
@@ -186,3 +187,118 @@ def test_removes_handoffs_from_history():
186187
assert len(filtered_data.input_history) == 1
187188
assert len(filtered_data.pre_handoff_items) == 1
188189
assert len(filtered_data.new_items) == 1
190+
191+
192+
def test_keep_last_n_items_basic():
193+
"""Test the basic functionality of keep_last_n_items."""
194+
handoff_input_data = HandoffInputData(
195+
input_history=(
196+
_get_message_input_item("Message 1"),
197+
_get_message_input_item("Message 2"),
198+
_get_message_input_item("Message 3"),
199+
_get_message_input_item("Message 4"),
200+
_get_message_input_item("Message 5"),
201+
),
202+
pre_handoff_items=(_get_message_output_run_item("Pre handoff"),),
203+
new_items=(_get_message_output_run_item("New item"),),
204+
)
205+
206+
# Keep last 2 items
207+
filtered_data = keep_last_n_items(handoff_input_data, 2)
208+
209+
assert len(filtered_data.input_history) == 2
210+
assert filtered_data.input_history[-1] == _get_message_input_item("Message 5")
211+
assert filtered_data.input_history[-2] == _get_message_input_item("Message 4")
212+
213+
# Pre-handoff and new items should remain unchanged
214+
assert len(filtered_data.pre_handoff_items) == 1
215+
assert len(filtered_data.new_items) == 1
216+
217+
218+
def test_keep_last_n_items_with_tool_messages():
219+
"""Test keeping last N items while removing tool messages."""
220+
handoff_input_data = HandoffInputData(
221+
input_history=(
222+
_get_message_input_item("Message 1"),
223+
_get_function_result_input_item("Function result"),
224+
_get_message_input_item("Message 2"),
225+
_get_handoff_input_item("Handoff"),
226+
_get_message_input_item("Message 3"),
227+
),
228+
pre_handoff_items=(_get_message_output_run_item("Pre handoff"),),
229+
new_items=(_get_message_output_run_item("New item"),),
230+
)
231+
232+
# Keep last 2 items but remove tool messages first
233+
filtered_data = keep_last_n_items(handoff_input_data, 2, keep_tool_messages=False)
234+
235+
# Should have the last 2 non-tool messages
236+
assert len(filtered_data.input_history) == 2
237+
assert filtered_data.input_history[-1] == _get_message_input_item("Message 3")
238+
assert filtered_data.input_history[-2] == _get_message_input_item("Message 2")
239+
240+
241+
def test_keep_last_n_items_all():
242+
"""Test keeping more items than exist."""
243+
handoff_input_data = HandoffInputData(
244+
input_history=(
245+
_get_message_input_item("Message 1"),
246+
_get_message_input_item("Message 2"),
247+
),
248+
pre_handoff_items=(_get_message_output_run_item("Pre handoff"),),
249+
new_items=(_get_message_output_run_item("New item"),),
250+
)
251+
252+
# Request more items than exist
253+
filtered_data = keep_last_n_items(handoff_input_data, 10)
254+
255+
# Should keep all items
256+
assert len(filtered_data.input_history) == 2
257+
assert filtered_data.input_history == handoff_input_data.input_history
258+
259+
260+
def test_keep_last_n_items_with_string_history():
261+
"""Test handling of string input_history."""
262+
handoff_input_data = HandoffInputData(
263+
input_history="This is a string history",
264+
pre_handoff_items=(_get_message_output_run_item("Pre handoff"),),
265+
new_items=(_get_message_output_run_item("New item"),),
266+
)
267+
268+
# String history should be preserved
269+
filtered_data = keep_last_n_items(handoff_input_data, 3)
270+
271+
assert filtered_data.input_history == "This is a string history"
272+
273+
274+
def test_keep_last_n_items_invalid_input():
275+
"""Test error handling for invalid inputs."""
276+
handoff_input_data = HandoffInputData(
277+
input_history=(_get_message_input_item("Message 1"),),
278+
pre_handoff_items=(),
279+
new_items=(),
280+
)
281+
282+
# Test with invalid n values
283+
with pytest.raises(ValueError, match="n must be a positive integer"):
284+
keep_last_n_items(handoff_input_data, 0)
285+
286+
with pytest.raises(ValueError, match="n must be a positive integer"):
287+
keep_last_n_items(handoff_input_data, -5)
288+
289+
with pytest.raises(ValueError, match="n must be an integer"):
290+
keep_last_n_items(handoff_input_data, "3")
291+
292+
293+
def test_keep_last_n_items_empty_history():
294+
"""Test with an empty input history."""
295+
handoff_input_data = HandoffInputData(
296+
input_history=(),
297+
pre_handoff_items=(_get_message_output_run_item("Pre handoff"),),
298+
new_items=(_get_message_output_run_item("New item"),),
299+
)
300+
301+
# Empty history should remain empty
302+
filtered_data = keep_last_n_items(handoff_input_data, 3)
303+
304+
assert len(filtered_data.input_history) == 0

0 commit comments

Comments
 (0)