diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index dc672acd..8f7e87cb 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -78,7 +78,7 @@ async def get_response(
             | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
             disabled=tracing.is_disabled(),
         ) as span_generation:
-            response = await self._fetch_response(
+            response, stream = await self._fetch_response(
                 system_instructions,
                 input,
                 model_settings,
@@ -87,16 +87,20 @@ async def get_response(
                 handoffs,
                 span_generation,
                 tracing,
-                stream=False,
+                stream=True,
             )
 
-            assert isinstance(response.choices[0], litellm.types.utils.Choices)
+            async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+                if chunk.type == "response.completed":
+                    response = chunk.response
+
+            message = Converter.output_items_to_message(response.output)
 
             if _debug.DONT_LOG_MODEL_DATA:
                 logger.debug("Received model response")
             else:
                 logger.debug(
-                    f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
+                    f"LLM resp:\n{json.dumps(message.model_dump(), indent=2)}\n"
                 )
 
             if hasattr(response, "usage"):
@@ -104,9 +108,9 @@ async def get_response(
                 usage = (
                     Usage(
                         requests=1,
-                        input_tokens=response_usage.prompt_tokens,
-                        output_tokens=response_usage.completion_tokens,
-                        total_tokens=response_usage.total_tokens,
+                        input_tokens=response.usage.input_tokens,
+                        output_tokens=response.usage.output_tokens,
+                        total_tokens=response.usage.total_tokens,
                     )
                     if response.usage
                     else Usage()
@@ -116,18 +120,14 @@ async def get_response(
                 logger.warning("No usage information returned from Litellm")
 
             if tracing.include_data():
-                span_generation.span_data.output = [response.choices[0].message.model_dump()]
+                span_generation.span_data.output = [message.model_dump()]
             span_generation.span_data.usage = {
                 "input_tokens": usage.input_tokens,
                 "output_tokens": usage.output_tokens,
             }
 
-            items = Converter.message_to_output_items(
-                LitellmConverter.convert_message_to_openai(response.choices[0].message)
-            )
-
             return ModelResponse(
-                output=items,
+                output=response.output,
                 usage=usage,
                 response_id=None,
             )
diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py
index 613a3745..c1a7bdff 100644
--- a/src/agents/models/chatcmpl_converter.py
+++ b/src/agents/models/chatcmpl_converter.py
@@ -20,6 +20,7 @@
     ChatCompletionUserMessageParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
 from openai.types.chat.completion_create_params import ResponseFormat
 from openai.types.responses import (
     EasyInputMessageParam,
@@ -81,6 +82,34 @@ def convert_response_format(
             },
         }
 
+    @classmethod
+    def output_items_to_message(cls, items: list[TResponseOutputItem]) -> ChatCompletionMessage:
+        tool_calls: list[ChatCompletionMessageToolCall] | None = None
+        message = ChatCompletionMessage(role="assistant")
+
+        for item in items:
+            if isinstance(item, ResponseOutputMessage):
+                if isinstance(item.content, ResponseOutputText):
+                    message.content = item.content.text
+                elif isinstance(item.content, ResponseOutputRefusal):
+                    message.refusal = item.content.refusal
+            elif isinstance(item, ResponseFunctionToolCall):
+                if tool_calls is None:
+                    tool_calls = []
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        id=item.call_id,
+                        type="function",
+                        function=Function(
+                            name=item.name,
+                            arguments=item.arguments,
+                        ),
+                    )
+                )
+
+        message.tool_calls = tool_calls
+        return message
+
     @classmethod
     def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TResponseOutputItem]:
         items: list[TResponseOutputItem] = []
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 89619f83..59c1841d 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -7,7 +7,7 @@
 
 from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
 from openai.types import ChatModel
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.responses import Response
 
 from .. import _debug
@@ -58,7 +58,7 @@ async def get_response(
             model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
             disabled=tracing.is_disabled(),
         ) as span_generation:
-            response = await self._fetch_response(
+            response, stream = await self._fetch_response(
                 system_instructions,
                 input,
                 model_settings,
@@ -67,37 +67,41 @@ async def get_response(
                 handoffs,
                 span_generation,
                 tracing,
-                stream=False,
+                stream=True,
             )
 
+            async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+                if chunk.type == "response.completed":
+                    response = chunk.response
+
+            message = Converter.output_items_to_message(response.output)
+
             if _debug.DONT_LOG_MODEL_DATA:
                 logger.debug("Received model response")
             else:
                 logger.debug(
-                    f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
+                    f"LLM resp:\n{json.dumps(message.model_dump(), indent=2)}\n"
                 )
 
             usage = (
                 Usage(
                     requests=1,
-                    input_tokens=response.usage.prompt_tokens,
-                    output_tokens=response.usage.completion_tokens,
+                    input_tokens=response.usage.input_tokens,
+                    output_tokens=response.usage.output_tokens,
                     total_tokens=response.usage.total_tokens,
                 )
                 if response.usage
                 else Usage()
             )
 
             if tracing.include_data():
-                span_generation.span_data.output = [response.choices[0].message.model_dump()]
+                span_generation.span_data.output = [message.model_dump()]
             span_generation.span_data.usage = {
                 "input_tokens": usage.input_tokens,
                 "output_tokens": usage.output_tokens,
            }
-            items = Converter.message_to_output_items(response.choices[0].message)
-
             return ModelResponse(
-                output=items,
+                output=response.output,
                 usage=usage,
                 response_id=None,
             )
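For reviewers, a minimal sketch of how the new converter helper behaves, exercising only the tool-call branch. It is illustrative, not part of the patch; it assumes the class shown above is exported as Converter from agents.models.chatcmpl_converter, and the tool name, call id, and arguments are hypothetical.

# Illustrative sketch (not part of the patch): feeds a Responses-style function
# tool call through the new Converter.output_items_to_message helper.
# The import path and the tool-call values below are assumptions for the demo.
from openai.types.responses import ResponseFunctionToolCall

from agents.models.chatcmpl_converter import Converter  # assumed export

items = [
    ResponseFunctionToolCall(
        type="function_call",
        call_id="call_123",            # hypothetical call id
        name="get_weather",            # hypothetical tool name
        arguments='{"city": "Tokyo"}',
    )
]

# Collapses output items back into a ChatCompletionMessage, as get_response
# now does before logging and tracing.
message = Converter.output_items_to_message(items)
assert message.tool_calls is not None
print(message.tool_calls[0].function.name)  # -> "get_weather"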