diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index b5a83685..d1b1d6e1 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -720,7 +720,7 @@ def stream_step_result_to_queue( elif isinstance(item, HandoffCallItem): event = RunItemStreamEvent(item=item, name="handoff_requested") elif isinstance(item, HandoffOutputItem): - event = RunItemStreamEvent(item=item, name="handoff_occured") + event = RunItemStreamEvent(item=item, name="handoff_occurred") elif isinstance(item, ToolCallItem): event = RunItemStreamEvent(item=item, name="tool_called") elif isinstance(item, ToolCallOutputItem): diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 066bbd83..ee14956e 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -38,7 +38,7 @@ def json_schema(self) -> dict[str, Any]: @abc.abstractmethod def is_strict_json_schema(self) -> bool: """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema - features, but guarantees valis JSON. See here for details: + features, but guarantees valid JSON. See here for details: https://platform.openai.com/docs/guides/structured-outputs#supported-schemas """ pass diff --git a/src/agents/extensions/handoff_filters.py b/src/agents/extensions/handoff_filters.py index f4f9b8bf..4ee0e928 100644 --- a/src/agents/extensions/handoff_filters.py +++ b/src/agents/extensions/handoff_filters.py @@ -65,3 +65,46 @@ def _remove_tool_types_from_input( continue filtered_items.append(item) return tuple(filtered_items) + + +def keep_last_n_items( + handoff_input_data: HandoffInputData, + n: int, + keep_tool_messages: bool = True +) -> HandoffInputData: + """ + Keep only the last n items in the input history. + If keep_tool_messages is False, remove tool messages first. + + Args: + handoff_input_data: The input data to filter + n: Number of items to keep from the end. Must be a positive integer. + If n is 1, only the last item is kept. + If n is greater than the number of items, all items are kept. + If n is less than or equal to 0, it raises a ValueError. + keep_tool_messages: If False, removes tool messages before filtering + + Raises: + ValueError: If n is not a positive integer + """ + if not isinstance(n, int): + raise ValueError(f"n must be an integer, got {type(n).__name__}") + if n <= 0: + raise ValueError(f"n must be a positive integer, got {n}") + + data = handoff_input_data + if not keep_tool_messages: + data = remove_all_tools(data) + + # Always ensure input_history and new_items are tuples for consistent slicing and return + history = ( + tuple(data.input_history)[-n:] + if isinstance(data.input_history, tuple) + else data.input_history + ) + + return HandoffInputData( + input_history=history, + pre_handoff_items=tuple(data.pre_handoff_items), + new_items=tuple(data.new_items), + ) diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index bd37d11f..eff345b9 100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -31,7 +31,7 @@ class RunItemStreamEvent: name: Literal[ "message_output_created", "handoff_requested", - "handoff_occured", + "handoff_occurred", "tool_called", "tool_output", "reasoning_item_created", diff --git a/tests/test_extension_filters.py b/tests/test_extension_filters.py index 4cb017aa..b42c2f92 100644 --- a/tests/test_extension_filters.py +++ b/tests/test_extension_filters.py @@ -1,7 +1,8 @@ +import pytest from openai.types.responses import ResponseOutputMessage, ResponseOutputText from agents import Agent, HandoffInputData -from agents.extensions.handoff_filters import remove_all_tools +from agents.extensions.handoff_filters import remove_all_tools, keep_last_n_items from agents.items import ( HandoffOutputItem, MessageOutputItem, @@ -186,3 +187,118 @@ def test_removes_handoffs_from_history(): assert len(filtered_data.input_history) == 1 assert len(filtered_data.pre_handoff_items) == 1 assert len(filtered_data.new_items) == 1 + + +def test_keep_last_n_items_basic(): + """Test the basic functionality of keep_last_n_items.""" + handoff_input_data = HandoffInputData( + input_history=( + _get_message_input_item("Message 1"), + _get_message_input_item("Message 2"), + _get_message_input_item("Message 3"), + _get_message_input_item("Message 4"), + _get_message_input_item("Message 5"), + ), + pre_handoff_items=(_get_message_output_run_item("Pre handoff"),), + new_items=(_get_message_output_run_item("New item"),), + ) + + # Keep last 2 items + filtered_data = keep_last_n_items(handoff_input_data, 2) + + assert len(filtered_data.input_history) == 2 + assert filtered_data.input_history[-1] == _get_message_input_item("Message 5") + assert filtered_data.input_history[-2] == _get_message_input_item("Message 4") + + # Pre-handoff and new items should remain unchanged + assert len(filtered_data.pre_handoff_items) == 1 + assert len(filtered_data.new_items) == 1 + + +def test_keep_last_n_items_with_tool_messages(): + """Test keeping last N items while removing tool messages.""" + handoff_input_data = HandoffInputData( + input_history=( + _get_message_input_item("Message 1"), + _get_function_result_input_item("Function result"), + _get_message_input_item("Message 2"), + _get_handoff_input_item("Handoff"), + _get_message_input_item("Message 3"), + ), + pre_handoff_items=(_get_message_output_run_item("Pre handoff"),), + new_items=(_get_message_output_run_item("New item"),), + ) + + # Keep last 2 items but remove tool messages first + filtered_data = keep_last_n_items(handoff_input_data, 2, keep_tool_messages=False) + + # Should have the last 2 non-tool messages + assert len(filtered_data.input_history) == 2 + assert filtered_data.input_history[-1] == _get_message_input_item("Message 3") + assert filtered_data.input_history[-2] == _get_message_input_item("Message 2") + + +def test_keep_last_n_items_all(): + """Test keeping more items than exist.""" + handoff_input_data = HandoffInputData( + input_history=( + _get_message_input_item("Message 1"), + _get_message_input_item("Message 2"), + ), + pre_handoff_items=(_get_message_output_run_item("Pre handoff"),), + new_items=(_get_message_output_run_item("New item"),), + ) + + # Request more items than exist + filtered_data = keep_last_n_items(handoff_input_data, 10) + + # Should keep all items + assert len(filtered_data.input_history) == 2 + assert filtered_data.input_history == handoff_input_data.input_history + + +def test_keep_last_n_items_with_string_history(): + """Test handling of string input_history.""" + handoff_input_data = HandoffInputData( + input_history="This is a string history", + pre_handoff_items=(_get_message_output_run_item("Pre handoff"),), + new_items=(_get_message_output_run_item("New item"),), + ) + + # String history should be preserved + filtered_data = keep_last_n_items(handoff_input_data, 3) + + assert filtered_data.input_history == "This is a string history" + + +def test_keep_last_n_items_invalid_input(): + """Test error handling for invalid inputs.""" + handoff_input_data = HandoffInputData( + input_history=(_get_message_input_item("Message 1"),), + pre_handoff_items=(), + new_items=(), + ) + + # Test with invalid n values + with pytest.raises(ValueError, match="n must be a positive integer"): + keep_last_n_items(handoff_input_data, 0) + + with pytest.raises(ValueError, match="n must be a positive integer"): + keep_last_n_items(handoff_input_data, -5) + + with pytest.raises(ValueError, match="n must be an integer"): + keep_last_n_items(handoff_input_data, "3") + + +def test_keep_last_n_items_empty_history(): + """Test with an empty input history.""" + handoff_input_data = HandoffInputData( + input_history=(), + pre_handoff_items=(_get_message_output_run_item("Pre handoff"),), + new_items=(_get_message_output_run_item("New item"),), + ) + + # Empty history should remain empty + filtered_data = keep_last_n_items(handoff_input_data, 3) + + assert len(filtered_data.input_history) == 0