From 9238c283fcad130b98d3960f9e1384ac7e3887e5 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 24 Apr 2025 13:43:33 -0400 Subject: [PATCH] Allow previous_response_id to be passed to input guardrails --- src/agents/__init__.py | 6 + src/agents/_run_impl.py | 3 +- src/agents/guardrail.py | 56 ++++-- src/agents/run.py | 12 +- tests/{ => guardrails}/test_guardrails.py | 45 ++++- tests/guardrails/test_new_input_guardrails.py | 184 ++++++++++++++++++ 6 files changed, 281 insertions(+), 25 deletions(-) rename tests/{ => guardrails}/test_guardrails.py (86%) create mode 100644 tests/guardrails/test_new_input_guardrails.py diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4..bfdf3c52 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -18,7 +18,10 @@ ) from .guardrail import ( GuardrailFunctionOutput, + InputGuardailInputs, InputGuardrail, + InputGuardrailFunction, + InputGuardrailFunctionLegacy, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult, @@ -174,6 +177,9 @@ def enable_verbose_stdout_logging(): "OutputGuardrail", "OutputGuardrailResult", "GuardrailFunctionOutput", + "InputGuardailInputs", + "InputGuardrailFunction", + "InputGuardrailFunctionLegacy", "input_guardrail", "output_guardrail", "handoff", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index b5a83685..1a79612f 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -688,10 +688,11 @@ async def run_single_input_guardrail( agent: Agent[Any], guardrail: InputGuardrail[TContext], input: str | list[TResponseInputItem], + previous_response_id: str | None, context: RunContextWrapper[TContext], ) -> InputGuardrailResult: with guardrail_span(guardrail.get_name()) as span_guardrail: - result = await guardrail.run(agent, input, context) + result = await guardrail.run(agent, input, context, previous_response_id) span_guardrail.span_data.triggered = result.output.tripwire_triggered return result diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index a96f0f7d..caa60e39 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -3,9 +3,9 @@ import inspect from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast, overload -from typing_extensions import TypeVar +from typing_extensions import TypeAlias, TypeVar from .exceptions import UserError from .items import TResponseInputItem @@ -68,6 +68,31 @@ class OutputGuardrailResult: """The output of the guardrail function.""" +InputGuardrailFunctionLegacy: TypeAlias = Callable[ + [RunContextWrapper[TContext], "Agent[Any]", Union[str, list[TResponseInputItem]]], + MaybeAwaitable[GuardrailFunctionOutput], +] +"""The legacy guardrail function signature, retained for backwards compatibility. Of the form: +def my_guardrail(ctx, agent, input) +""" + + +@dataclass +class InputGuardailInputs: + agent: Agent[Any] + input: str | list[TResponseInputItem] + previous_response_id: str | None + + +InputGuardrailFunction: TypeAlias = Callable[ + [RunContextWrapper[TContext], InputGuardailInputs], + MaybeAwaitable[GuardrailFunctionOutput], +] +"""The new guardrail function signature, of the form: +def my_guardrail(ctx, inputs) +""" + + @dataclass class InputGuardrail(Generic[TContext]): """Input guardrails are checks that run in parallel to the agent's execution. @@ -82,10 +107,7 @@ class InputGuardrail(Generic[TContext]): execution will immediately stop and a `InputGuardrailTripwireTriggered` exception will be raised """ - guardrail_function: Callable[ - [RunContextWrapper[TContext], Agent[Any], str | list[TResponseInputItem]], - MaybeAwaitable[GuardrailFunctionOutput], - ] + guardrail_function: InputGuardrailFunction[TContext] | InputGuardrailFunctionLegacy[TContext] """A function that receives the agent input and the context, and returns a `GuardrailResult`. The result marks whether the tripwire was triggered, and can optionally include information about the guardrail's output. @@ -107,11 +129,21 @@ async def run( agent: Agent[Any], input: str | list[TResponseInputItem], context: RunContextWrapper[TContext], + previous_response_id: str | None, ) -> InputGuardrailResult: if not callable(self.guardrail_function): raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") - output = self.guardrail_function(context, agent, input) + sig = inspect.signature(self.guardrail_function) + if len(sig.parameters) == 3: + # Legacy guardrail function + legacy_function = cast(InputGuardrailFunctionLegacy[TContext], self.guardrail_function) + output = legacy_function(context, agent, input) + else: + # New guardrail function + new_function = cast(InputGuardrailFunction[TContext], self.guardrail_function) + output = new_function(context, InputGuardailInputs(agent, input, previous_response_id)) + if inspect.isawaitable(output): return InputGuardrailResult( guardrail=self, @@ -182,13 +214,11 @@ async def run( TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) # For InputGuardrail -_InputGuardrailFuncSync = Callable[ - [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]], - GuardrailFunctionOutput, +_InputGuardrailFuncSync = Union[ + InputGuardrailFunctionLegacy[TContext_co], InputGuardrailFunction[TContext_co] ] -_InputGuardrailFuncAsync = Callable[ - [RunContextWrapper[TContext_co], "Agent[Any]", Union[str, list[TResponseInputItem]]], - Awaitable[GuardrailFunctionOutput], +_InputGuardrailFuncAsync = Union[ + InputGuardrailFunctionLegacy[TContext_co], InputGuardrailFunction[TContext_co] ] diff --git a/src/agents/run.py b/src/agents/run.py index 2af558d5..0ad93d71 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -221,6 +221,7 @@ async def run( starting_agent.input_guardrails + (run_config.input_guardrails or []), copy.deepcopy(input), + previous_response_id, context_wrapper, ), cls._run_single_turn( @@ -446,6 +447,7 @@ async def _run_input_guardrails_with_queue( agent: Agent[Any], guardrails: list[InputGuardrail[TContext]], input: str | list[TResponseInputItem], + previous_response_id: str | None, context: RunContextWrapper[TContext], streamed_result: RunResultStreaming, parent_span: Span[Any], @@ -455,7 +457,9 @@ async def _run_input_guardrails_with_queue( # We'll run the guardrails and push them onto the queue as they complete guardrail_tasks = [ asyncio.create_task( - RunImpl.run_single_input_guardrail(agent, guardrail, input, context) + RunImpl.run_single_input_guardrail( + agent, guardrail, input, previous_response_id, context + ) ) for guardrail in guardrails ] @@ -551,6 +555,7 @@ async def _run_streamed_impl( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), + previous_response_id, context_wrapper, streamed_result, current_span, @@ -825,6 +830,7 @@ async def _run_input_guardrails( agent: Agent[Any], guardrails: list[InputGuardrail[TContext]], input: str | list[TResponseInputItem], + previous_response_id: str | None, context: RunContextWrapper[TContext], ) -> list[InputGuardrailResult]: if not guardrails: @@ -832,7 +838,9 @@ async def _run_input_guardrails( guardrail_tasks = [ asyncio.create_task( - RunImpl.run_single_input_guardrail(agent, guardrail, input, context) + RunImpl.run_single_input_guardrail( + agent, guardrail, input, previous_response_id, context + ) ) for guardrail in guardrails ] diff --git a/tests/test_guardrails.py b/tests/guardrails/test_guardrails.py similarity index 86% rename from tests/test_guardrails.py rename to tests/guardrails/test_guardrails.py index c9f318c3..bb37c880 100644 --- a/tests/test_guardrails.py +++ b/tests/guardrails/test_guardrails.py @@ -32,14 +32,20 @@ def sync_guardrail( async def test_sync_input_guardrail(): guardrail = InputGuardrail(guardrail_function=get_sync_guardrail(triggers=False)) result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) assert not result.output.tripwire_triggered assert result.output.output_info is None guardrail = InputGuardrail(guardrail_function=get_sync_guardrail(triggers=True)) result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) assert result.output.tripwire_triggered assert result.output.output_info is None @@ -48,7 +54,10 @@ async def test_sync_input_guardrail(): guardrail_function=get_sync_guardrail(triggers=True, output_info="test") ) result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) assert result.output.tripwire_triggered assert result.output.output_info == "test" @@ -70,14 +79,20 @@ async def async_guardrail( async def test_async_input_guardrail(): guardrail = InputGuardrail(guardrail_function=get_async_input_guardrail(triggers=False)) result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) assert not result.output.tripwire_triggered assert result.output.output_info is None guardrail = InputGuardrail(guardrail_function=get_async_input_guardrail(triggers=True)) result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) assert result.output.tripwire_triggered assert result.output.output_info is None @@ -86,7 +101,10 @@ async def test_async_input_guardrail(): guardrail_function=get_async_input_guardrail(triggers=True, output_info="test") ) result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) assert result.output.tripwire_triggered assert result.output.output_info == "test" @@ -98,7 +116,10 @@ async def test_invalid_input_guardrail_raises_user_error(): # Purposely ignoring type error guardrail = InputGuardrail(guardrail_function="foo") # type: ignore await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, ) @@ -210,14 +231,20 @@ def decorated_named_input_guardrail( async def test_input_guardrail_decorators(): guardrail = decorated_input_guardrail result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + previous_response_id=None, + context=RunContextWrapper(context=None), ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_1" guardrail = decorated_named_input_guardrail result = await guardrail.run( - agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + agent=Agent(name="test"), + input="test", + previous_response_id=None, + context=RunContextWrapper(context=None), ) assert not result.output.tripwire_triggered assert result.output.output_info == "test_2" diff --git a/tests/guardrails/test_new_input_guardrails.py b/tests/guardrails/test_new_input_guardrails.py new file mode 100644 index 00000000..0271aa5b --- /dev/null +++ b/tests/guardrails/test_new_input_guardrails.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from agents import ( + Agent, + GuardrailFunctionOutput, + InputGuardailInputs, + InputGuardrail, + RunContextWrapper, + TResponseInputItem, +) +from agents.guardrail import input_guardrail + + +def get_sync_guardrail(triggers: bool, output_info: Any | None = None): + def sync_guardrail(context: RunContextWrapper[Any], inputs: InputGuardailInputs): + assert inputs.agent is not None + assert inputs.input is not None + + return GuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return sync_guardrail + + +@pytest.mark.asyncio +async def test_sync_input_guardrail(): + guardrail = InputGuardrail(guardrail_function=get_sync_guardrail(triggers=False)) + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, + ) + assert not result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = InputGuardrail(guardrail_function=get_sync_guardrail(triggers=True)) + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, + ) + assert result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = InputGuardrail( + guardrail_function=get_sync_guardrail(triggers=True, output_info="test") + ) + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, + ) + assert result.output.tripwire_triggered + assert result.output.output_info == "test" + + +def get_async_input_guardrail(triggers: bool, output_info: Any | None = None): + async def async_guardrail(context: RunContextWrapper[Any], inputs: InputGuardailInputs): + assert inputs.agent is not None + assert inputs.input is not None + + return GuardrailFunctionOutput( + output_info=output_info, + tripwire_triggered=triggers, + ) + + return async_guardrail + + +@pytest.mark.asyncio +async def test_async_input_guardrail(): + guardrail = InputGuardrail(guardrail_function=get_async_input_guardrail(triggers=False)) + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, + ) + assert not result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = InputGuardrail(guardrail_function=get_async_input_guardrail(triggers=True)) + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, + ) + assert result.output.tripwire_triggered + assert result.output.output_info is None + + guardrail = InputGuardrail( + guardrail_function=get_async_input_guardrail(triggers=True, output_info="test") + ) + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + context=RunContextWrapper(context=None), + previous_response_id=None, + ) + assert result.output.tripwire_triggered + assert result.output.output_info == "test" + + +@input_guardrail +def decorated_input_guardrail( + context: RunContextWrapper[Any], inputs: InputGuardailInputs +) -> GuardrailFunctionOutput: + assert inputs.agent is not None + assert inputs.input is not None + + return GuardrailFunctionOutput( + output_info="test_1", + tripwire_triggered=False, + ) + + +@input_guardrail(name="Custom name") +def decorated_named_input_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] +) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="test_2", + tripwire_triggered=False, + ) + + +@pytest.mark.asyncio +async def test_input_guardrail_decorators(): + guardrail = decorated_input_guardrail + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + previous_response_id=None, + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_1" + + guardrail = decorated_named_input_guardrail + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + previous_response_id=None, + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_2" + assert guardrail.get_name() == "Custom name" + + +@input_guardrail +def guardrail_with_previous_response_id( + context: RunContextWrapper[Any], inputs: InputGuardailInputs +) -> GuardrailFunctionOutput: + assert inputs.agent is not None + assert inputs.input is not None + assert inputs.previous_response_id is not None + return GuardrailFunctionOutput( + output_info="test_3", + tripwire_triggered=False, + ) + + +@pytest.mark.asyncio +async def test_guardrail_with_previous_response_id(): + guardrail = guardrail_with_previous_response_id + result = await guardrail.run( + agent=Agent(name="test"), + input="test", + previous_response_id="test", + context=RunContextWrapper(context=None), + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "test_3"