@@ -1,11 +1,14 @@
+import asyncio
 import json
 import logging
 import os
 from typing import Any, Dict, List, Optional
 
-import openai_multi_tool_use_parallel_patch  # type: ignore  # noqa: F401
+# TODO: restore multi tool use support; the upstream package may have fixed it by now.
+# Commented out because it does not work with streams.
+# import openai_multi_tool_use_parallel_patch  # type: ignore  # noqa: F401
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from openai import OpenAI
 from tools.tool_manager import ToolManager
 
@@ -188,3 +191,63 @@ def _translate_response(self, response) -> PromptResponse:
             # logger.error("An error occurred while translating OpenAI response: %s", e, exc_info=True)
             logger.exception(f"error: {e}\nresponse: {response}")
             raise e
+
+    # https://cookbook.openai.com/examples/how_to_stream_completions
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        system_message = [{"role": "system", "content": self.system_prompt}] if self.system_prompt else []
+        combined_history = system_message + past_messages + [{"role": "user", "content": prompt}]
+
+        try:
+            # Build the request kwargs so that tools is omitted entirely when
+            # none are provided, rather than sending tools=None to the API.
+            create_kwargs: Dict[str, Any] = {
+                "model": self.model,
+                "messages": combined_history,
+                "stream": True,
+            }
+            if tools:
+                create_kwargs["tools"] = tools
+            stream = self.client.chat.completions.create(**create_kwargs)
+
+            async def content_generator():
+                # NOTE: iterating the synchronous SDK stream blocks the event
+                # loop while waiting for each chunk.
+                for event in stream:
+                    if event.choices and event.choices[0].delta.content is not None:
+                        yield event.choices[0].delta.content
+                    # Yield control so other tasks can run between chunks.
+                    await asyncio.sleep(0)
+
+                # After the stream completes, tool calls could be handled here.
+                # Doing so requires accumulating the streamed deltas (e.g. into
+                # a collected_content variable) and inspecting them for a tool call:
+                # if tools and collected_content.strip().startswith('{"function":'):
+                #     tool_response = await self.handle_tool_call(collected_content, combined_history, tools)
+                #     yield tool_response
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # Token counts are unknown until the stream is consumed
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.error("Error generating text stream: %s", e, exc_info=True)
+            raise
+
+    async def handle_tool_call(self, collected_content, combined_history, tools):
+        # Placeholder for handling tool calls in a streaming context: parse the
+        # tool call, execute it via the tool manager, and generate a response
+        # from the tool's output. This likely means ending the current stream
+        # and making a new API call with the tool result appended to the history.
+        raise NotImplementedError
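
For reference, a minimal sketch of how the placeholder above could collect tool calls from the same stream. The accumulate_tool_calls helper is hypothetical and not part of this diff, but the delta.tool_calls fragment fields (index, id, function.name, function.arguments) are how the chat completions API streams tool calls; the argument text arrives as incremental JSON fragments that must be concatenated before parsing.

import json

def accumulate_tool_calls(stream):
    # index -> partial call; fragments belonging to one call share an index.
    calls = {}
    for event in stream:
        if not event.choices:
            continue
        for fragment in event.choices[0].delta.tool_calls or []:
            call = calls.setdefault(
                fragment.index, {"id": "", "name": "", "arguments": ""}
            )
            if fragment.id:
                call["id"] = fragment.id
            if fragment.function and fragment.function.name:
                call["name"] = fragment.function.name
            if fragment.function and fragment.function.arguments:
                call["arguments"] += fragment.function.arguments
    # Arguments are streamed as JSON text; parse them once the stream ends.
    return [
        {**call, "arguments": json.loads(call["arguments"] or "{}")}
        for _, call in sorted(calls.items())
    ]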
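
And a sketch of how a caller might consume the result; llm stands in for an instance of this client class, whose constructor is not shown in this diff.

import asyncio

async def main(llm):
    response = await llm.generate_text_stream(
        prompt="Name three prime numbers.",
        past_messages=[],
    )
    # response.content is the async generator built in generate_text_stream.
    async for chunk in response.content:
        print(chunk, end="", flush=True)
    print()

# asyncio.run(main(llm))  # run with a constructed client instance

The zeroed UsageStats could later be filled in for real: the chat completions API can report token counts in a final stream chunk when the request includes stream_options={"include_usage": True}.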