@@ -1,8 +1,10 @@
+import asyncio
 import logging
+from typing import Any, Dict, List, Optional
 
 import cohere
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from tools.tool_manager import ToolManager
 
 logging.basicConfig(level=logging.INFO)
@@ -121,3 +123,75 @@ def _translate_response(self, response) -> PromptResponse:
             )
             raise e
 
+    # https://github.com/cohere-ai/cohere-python/blob/main/src/cohere/types/streamed_chat_response.py
+    # https://docs.cohere.com/docs/streaming#stream-events
+    # https://docs.cohere.com/docs/streaming#example-responses
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        try:
+            # Cohere takes the current message separately from the chat
+            # history, which uses CHATBOT/USER roles and a "message" key.
+            combined_history = [
+                {
+                    "role": "CHATBOT" if msg["role"] == "assistant" else "USER",
+                    "message": msg["content"],
+                }
+                for msg in past_messages
+            ]
+            stream = self.client.chat_stream(
+                chat_history=combined_history,
+                message=prompt,
+                tools=tools,
+                model=self.model,
+                # Perform a web search before answering the question.
+                # You can also use your own custom connector.
+                # connectors=[{"id": "web-search"}],
+            )
+
+            async def content_generator():
+                for event in stream:
+                    if isinstance(event, cohere.types.StreamedChatResponse_StreamStart):
+                        # Message start event; nothing to emit
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_TextGeneration):
+                        # This is the event that contains the actual text
+                        if event.text:
+                            yield event.text
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsGeneration):
+                        # TODO: dispatch the requested tool call(s) via handle_tool_call
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_CitationGeneration):
+                        # TODO: citation metadata for connector results; worth surfacing
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsChunk):
+                        # TODO: streamed chunks of the tool-call payload
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchQueriesGeneration):
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchResults):
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_StreamEnd):
+                        # Message stop event; nothing more to emit
+                        pass
+                    # Yield control to the event loop between events
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # these will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Cohere: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
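+        # TODO: run the requested tool calls (presumably via the imported
+        # ToolManager) and send their outputs back to the model; still a
+        # stub in this commit.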
+        pass
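
For reference, a minimal, hypothetical consumer of the new method (not part of this commit): it assumes the class in this diff is instantiated as llm (the class name and constructor are not shown here) and iterates the async generator exposed as StreamingPromptResponse.content.

import asyncio

async def main():
    # Hypothetical instantiation; the wrapper's class name and constructor
    # arguments are not shown in this diff.
    llm = CohereLLM(model="command-r")
    response = await llm.generate_text_stream(
        prompt="What is retrieval-augmented generation?",
        past_messages=[
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ],
    )
    # content is the async generator built in generate_text_stream;
    # iterate it to print tokens as they arrive.
    async for chunk in response.content:
        print(chunk, end="", flush=True)

asyncio.run(main())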
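And a rough sketch of what the handle_tool_call stub might grow into, following Cohere's documented tool-use flow of passing tool outputs back via tool_results on a follow-up chat call; run_tool is a hypothetical ToolManager method, not something this diff defines.

    async def handle_tool_call(self, tool_calls, combined_history, tools):
        # Hypothetical: execute each tool the model requested; run_tool is
        # a stand-in for whatever ToolManager actually exposes.
        tool_results = []
        for tc in tool_calls:
            output = self.tool_manager.run_tool(tc.name, **tc.parameters)
            tool_results.append({"call": tc, "outputs": [output]})
        # Send the tool outputs back so the model can produce a final answer.
        return self.client.chat(
            message="",
            chat_history=combined_history,
            tools=tools,
            tool_results=tool_results,
            model=self.model,
        )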