diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index dc672acd..fc42ba50 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -228,6 +228,10 @@ async def _fetch_response(
                     "role": "system",
                 },
             )
+
+        if model_settings.patch_json_mode:
+            Converter.patch_messages_json_mode(converted_messages, output_schema)
+
         if tracing.include_data():
             span.span_data.input = converted_messages
 
@@ -239,7 +243,9 @@ async def _fetch_response(
             else None
         )
        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
-        response_format = Converter.convert_response_format(output_schema)
+        response_format = Converter.convert_response_format(
+            output_schema, model_settings.patch_json_mode
+        )
 
         converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else []
 
diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py
index 7b016c98..394f8bfb 100644
--- a/src/agents/model_settings.py
+++ b/src/agents/model_settings.py
@@ -73,6 +73,10 @@ class ModelSettings:
     """Additional headers to provide with the request.
     Defaults to None if not provided."""
 
+    patch_json_mode: bool = False
+    """Whether to request plain JSON mode instead of a strict JSON schema, injecting the
+    output schema into the system message instead. Defaults to False."""
+
     def resolve(self, override: ModelSettings | None) -> ModelSettings:
         """Produce a new ModelSettings by overlaying any non-None values from the
         override on top of this instance."""
diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py
index 613a3745..894f65d3 100644
--- a/src/agents/models/chatcmpl_converter.py
+++ b/src/agents/models/chatcmpl_converter.py
@@ -2,6 +2,7 @@
 
 import json
 from collections.abc import Iterable
+from textwrap import dedent
 from typing import Any, Literal, cast
 
 from openai import NOT_GIVEN, NotGiven
@@ -67,11 +68,14 @@ def convert_tool_choice(
 
     @classmethod
     def convert_response_format(
-        cls, final_output_schema: AgentOutputSchemaBase | None
+        cls, final_output_schema: AgentOutputSchemaBase | None, patch_json_mode: bool = False
     ) -> ResponseFormat | NotGiven:
         if not final_output_schema or final_output_schema.is_plain_text():
             return NOT_GIVEN
 
+        if patch_json_mode:
+            return {"type": "json_object"}
+
         return {
             "type": "json_schema",
             "json_schema": {
@@ -437,6 +441,38 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
         flush_assistant_message()
         return result
 
+    @classmethod
+    def patch_messages_json_mode(
+        cls,
+        items: list[ChatCompletionMessageParam],
+        output_schema: AgentOutputSchemaBase | None,
+    ) -> None:
+        """
+        In JSON mode, add the output JSON schema to the system message, creating one if absent.
+        """
+        if output_schema is None or output_schema.is_plain_text():
+            return
+
+        message = dedent(f"""
+            As a genius expert, your task is to understand the content and provide
+            the parsed objects in json that match the following json_schema:
+
+            {json.dumps(output_schema.json_schema(), indent=2, ensure_ascii=False)}
+
+            Make sure to return an instance of the JSON, not the schema itself.
+        """)
+
+        if items[0]["role"] != "system":
+            items.insert(0, ChatCompletionSystemMessageParam(content=message, role="system"))
+        elif isinstance(items[0]["content"], str):
+            items[0]["content"] += f"\n\n{message}"
+        elif isinstance(items[0]["content"], list):
+            items[0]["content"][0]["text"] += f"\n\n{message}"
+        else:
+            raise ValueError(
+                "Invalid system message content; expected a string or a list of content parts"
+            )
+
     @classmethod
     def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam:
         if isinstance(tool, FunctionTool):
diff --git a/src/agents/models/json.py b/src/agents/models/json.py
new file mode 100644
index 00000000..5a9bedc3
--- /dev/null
+++ b/src/agents/models/json.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from typing import TYPE_CHECKING
+
+from ..agent_output import AgentOutputSchemaBase
+from ..handoffs import Handoff
+from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
+from ..tool import Tool
+from .interface import ModelTracing
+
+if TYPE_CHECKING:
+    from ..model_settings import ModelSettings
+
+
+class JSONModeModelMixin:
+    """A mixin interface for models that return structured output via JSON mode."""
+
+    async def get_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        *,
+        previous_response_id: str | None,
+    ) -> ModelResponse:
+        raise NotImplementedError
+
+    def stream_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        *,
+        previous_response_id: str | None,
+    ) -> AsyncIterator[TResponseStreamEvent]:
+        raise NotImplementedError